def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. if self.isTrain: self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] else: # during test time, only load Gs self.model_names = ['G_A', 'G_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: # define discriminators self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert(opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D)
def __init__(self, opt):
    super(MaskMobileCycleGANModel, self).__init__()
    self.opt = opt
    self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(opt.gpu_ids) > 0 else 'cpu'
    self.loss_names = [
        'D_A', 'G_A', 'cycle_A', 'idt_A',
        'D_B', 'G_B', 'cycle_B', 'idt_B', 'mask_weight'
    ]
    visual_names_A = ['real_A', 'fake_B', 'rec_A', 'idt_B']
    visual_names_B = ['real_B', 'fake_A', 'rec_B', 'idt_A']
    self.visual_names = visual_names_A + visual_names_B

    self.netG_A = MaskMobileResnetGenerator(opt=self.opt, ngf=self.opt.ngf)
    self.netG_B = MaskMobileResnetGenerator(opt=self.opt, ngf=self.opt.ngf)
    self.netD_A = NLayerDiscriminator(ndf=self.opt.ndf)
    self.netD_B = NLayerDiscriminator(ndf=self.opt.ndf)
    self.init_net()

    self.fake_A_pool = ImagePool(50)
    self.fake_B_pool = ImagePool(50)

    self.group_mask_weight_names = []
    self.group_mask_weight_names.append('model.11')
    for i in range(13, 22, 1):
        self.group_mask_weight_names.append('model.%d.conv_block.9' % i)

    self.stop_AtoB_mask = False
    self.stop_BtoA_mask = False

    # define loss functions
    self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
    self.criterionCycle = nn.L1Loss()
    self.criterionIdt = nn.L1Loss()

    # define optimizers
    self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                        lr=opt.lr, betas=(0.5, 0.999))
    self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                        lr=opt.lr, betas=(0.5, 0.999))
    self.optimizers = []
    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)
    self.schedulers = [util.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
def __init__(self, args):
    super().__init__(args)

    if args.mode == 'train':
        self.D = define_D(args)
        self.D = self.D.to(self.device)
        self.fake_right_pool = ImagePool(50)

        self.criterionMonoDepth = define_generator_loss(args)
        self.criterionMonoDepth = self.criterionMonoDepth.to(self.device)
        self.criterionGAN = define_discriminator_loss(args)
        self.criterionGAN = self.criterionGAN.to(self.device)

    # Load the correct networks, depending on which mode we are in.
    if args.mode == 'train':
        self.model_names = ['G', 'D']
        self.optimizer_names = ['G', 'D']
    else:
        self.model_names = ['G']

    self.loss_names = ['G', 'D']

    # We do Resume Training for this architecture.
    if args.resume == '':
        pass
    else:
        self.load_checkpoint(load_optim=False)

    if args.mode == 'train':
        # After resuming, set new optimizers.
        self.optimizer_G = optim.SGD(self.G.parameters(), lr=args.learning_rate)
        self.optimizer_D = optim.SGD(self.D.parameters(), lr=args.learning_rate)

        # Reset epoch.
        self.start_epoch = 0
        self.trainG = True
        self.count_trained_G = 0
        self.count_trained_D = 0
        self.regime = args.resume_regime

    if 'cuda' in self.device:
        torch.cuda.synchronize()
def __init__(self, opt, G_A, G_B, D_A, D_B, optimizer_G, optimizer_D, summary_writer):
    self.opt = opt
    self.device = th.device('cuda:{}'.format(self.opt.gpu_ids[0])) if self.opt.gpu_ids else th.device('cpu')
    self.G_A = G_A
    self.G_B = G_B
    self.D_A = D_A
    self.D_B = D_B
    # define optimizer G and D
    self.optimizer_G = optimizer_G
    self.optimizer_D = optimizer_D
    self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
    self.criterionCycle = th.nn.L1Loss()
    self.criterionIdt = th.nn.L1Loss()
    self.summary_writer = summary_writer
    self.fake_B_pool = ImagePool(self.opt.pool_size)
    self.fake_A_pool = ImagePool(self.opt.pool_size)
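# Several snippets build GANLoss(use_lsgan=...) or GANLoss(opt.gan_mode). A
# minimal sketch of the common behavior follows: the criterion turns a
# True/False target into a tensor of ones/zeros matching the discriminator
# output, then applies MSE (least-squares GAN) or BCE-with-logits. Actual
# repo classes add more modes (e.g., wgangp); this sketch covers only the
# two used above.
import torch
import torch.nn as nn

class GANLossSketch(nn.Module):
    def __init__(self, use_lsgan=True):
        super().__init__()
        # lsgan regresses patch scores to 0/1; vanilla GAN uses logistic loss
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        return self.loss(prediction, target)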
def __init__(self, opt, cfg_AtoB=None, cfg_BtoA=None):
    super(MobileCycleGANModel, self).__init__()
    self.opt = opt
    self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(opt.gpu_ids) > 0 else 'cpu'
    self.cfg_AtoB = cfg_AtoB
    self.cfg_BtoA = cfg_BtoA
    self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
    visual_names_A = ['real_A', 'fake_B', 'rec_A', 'idt_B']
    visual_names_B = ['real_B', 'fake_A', 'rec_B', 'idt_A']
    self.visual_names = visual_names_A + visual_names_B

    self.netG_A = MobileResnetGenerator(opt=self.opt, cfg=cfg_AtoB)
    self.netG_B = MobileResnetGenerator(opt=self.opt, cfg=cfg_BtoA)
    self.netD_A = NLayerDiscriminator()
    self.netD_B = NLayerDiscriminator()
    self.init_net()

    self.fake_A_pool = ImagePool(50)
    self.fake_B_pool = ImagePool(50)

    self.teacher_model = None
    if self.opt.lambda_attention_distill > 0:
        print('init attention distill')
        self.init_attention_distill()
    if self.opt.lambda_discriminator_distill > 0:
        print('init discriminator distill')
        self.init_discriminator_distill()

    # define loss functions
    self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
    self.criterionCycle = nn.L1Loss()
    self.criterionIdt = nn.L1Loss()

    # define optimizers
    self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                        lr=opt.lr, betas=(0.5, 0.999))
    self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                        lr=opt.lr, betas=(0.5, 0.999))
    self.optimizers = []
    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)
    self.schedulers = [util.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
def initialize(self, opt):
    super(CrossModelV, self).initialize(opt)
    self.netG = GModel()
    self.netD = DModel()
    self.netG.initialize(opt)
    self.netD.initialize(opt)
    self.criterionGAN = GANLoss(opt.use_lsgan)
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                        lr=opt.learn_rate,
                                        # betas=(.5, 0.9)
                                        betas=(.5, 0.999))
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                        lr=opt.learn_rate,
                                        betas=(.5, 0.999))
    self.pool = ImagePool(160)
    init_net(self)
    print(self)
def __init__(self, params):
    super(network, self).__init__()
    self.Tensor = torch.cuda.FloatTensor
    self.configurate(params['net'])

    self.fake_pool_x = ImagePool(self.pool_size)
    self.fake_pool_y = ImagePool(self.pool_size)
    self.input_x = self.Tensor(self.batch_size, 3, 256, 256)
    self.input_y = self.Tensor(self.batch_size, 3, 256, 256)
    self.target_x = self.Tensor(self.batch_size, 1, 256, 256)
    self.target_y = self.Tensor(self.batch_size, 1, 256, 256)
    self.tf_summary = Logger('./logs', self.name)

    self.enc_x = Encoder(**params['enc_x']).cuda()
    self.enc_y = Encoder(**params['enc_y']).cuda()
    self.mul_gen_x = Multitask_Generator(**params['gen_x']).cuda()
    self.mul_gen_y = Multitask_Generator(**params['gen_y']).cuda()
    self.dis_x = NLayerDiscriminator(**params['dis']).cuda()
    self.dis_y = NLayerDiscriminator(**params['dis']).cuda()

    self.criterionGAN = GANLoss()
    self.criterionCyC = torch.nn.L1Loss()
    self.criterionSeg = Segmentation_Loss()

    self.optimizer_G = torch.optim.Adam(itertools.chain(self.enc_x.parameters(), self.mul_gen_x.parameters(),
                                                        self.enc_y.parameters(), self.mul_gen_y.parameters()),
                                        lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D_A = torch.optim.Adam(self.dis_x.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D_B = torch.optim.Adam(self.dis_y.parameters(), lr=self.lr, betas=(0.5, 0.999))
def get_discriminator_input_fn(conf, disc_conf, no_pool=False):
    if disc_conf.get_attr('use_image_pool', default=False) and not no_pool:
        pool_size = disc_conf.get_attr('image_pool_size', default=5 * conf.batch_size)
        sample_prob = disc_conf.get_attr('image_pool_sample_prob', default=0.5)
        image_pool = ImagePool(pool_size, sample_prob)
    else:
        image_pool = None

    pool_label_swapping = disc_conf.get_attr('image_pool_label_swapping', default=False)
    input_method = disc_conf.get_attr('input_method', default=DEFAULT_INPUT_METHOD)
    normalize_input = disc_conf.get_attr('normalize_input', default=False)
    scale_input = disc_conf.get_attr('scale_input_zero_one', default=False)
    strip_bg_class = disc_conf.get_attr('strip_bg_class', default=False)

    cond_input_src = disc_conf.get_attr('conditional_input_source', default='input')
    if cond_input_src == 'input':
        cond_input_src = CondInputSource.INPUT
    elif cond_input_src == 'generator':
        cond_input_src = CondInputSource.OUT_GEN
    else:
        raise ValueError(('Unknown conditional '
                          'input source {}').format(cond_input_src))
    cond_input_gen_key = disc_conf.get_attr('conditional_input_generator_key')

    disc_input_fn = _build_input_fn(input_method, normalize_input, image_pool,
                                    cond_input_src, cond_input_gen_key,
                                    strip_bg_class, scale_input,
                                    pool_label_swapping)
    return disc_input_fn
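# Hypothetical usage of get_discriminator_input_fn above. SimpleNamespace
# stands in for the training conf (only batch_size is read here) and DictConf
# is an assumed stand-in for whatever object provides get_attr(name,
# default=...) in this codebase.
from types import SimpleNamespace

class DictConf:
    def __init__(self, d):
        self._d = d

    def get_attr(self, name, default=None):
        return self._d.get(name, default)

conf = SimpleNamespace(batch_size=4)
disc_conf = DictConf({'use_image_pool': True, 'image_pool_size': 20})
disc_input_fn = get_discriminator_input_fn(conf, disc_conf)  # pools 20 fakes, sample_prob defaults to 0.5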
class SSRGAN(BaseModel):
    def name(self):
        return 'SSRGAN'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake), flags) if f]

        return loss_filter

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.input_nc
        self.para = opt.trade_off

        # define networks
        # Generator network
        netG_input_nc = input_nc
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global,
                                      opt.n_local_enhancers, opt.n_blocks_local,
                                      opt.norm, gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # netD_input_nc = input_nc + opt.output_nc
            netD_input_nc = opt.output_nc
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.num_D, not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        # Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            # AWAN
            self.criterionCSS = networks.CSS()

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_CSS', 'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------'
                      % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, rgb, hyper, infer=False):
        # RGB for training
        if rgb is not None:
            rgb = Variable(rgb.data.cuda())
        # hyper for training
        if hyper is not None:
            hyper = Variable(hyper.data.cuda())
        return rgb, hyper

    def discriminate(self, rgb, hyper, use_pool=False):
        # input_concat = torch.cat((rgb, hyper.detach()), dim=1)
        input_concat = hyper.detach()
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, rgb, hyper, infer=False):
        # Encode Inputs
        rgb, real_hyper = self.encode_input(rgb, hyper)

        # Fake Generation
        input_concat = rgb
        fake_hyper = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(rgb, fake_hyper, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(rgb, real_hyper)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        # pred_fake = self.netD.forward(torch.cat((rgb, fake_hyper), dim=1))
        pred_fake = self.netD.forward(fake_hyper)
        loss_G_GAN = self.criterionGAN(pred_fake, True)
        lrm, lrm_rgb = self.criterionCSS(fake_hyper, real_hyper, rgb)
        loss_G_GAN += lrm + self.para * lrm_rgb  # default 10

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        # loss_G_VGG = 0
        loss_G_CSS = lrm + self.para * lrm_rgb

        # Only return the fake_B image if necessary to save BW
        # return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake), None if not infer else fake_hyper]
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_CSS, loss_D_real, loss_D_fake),
                None if not infer else fake_hyper]

    def inference(self, rgb, hyper, image=None):
        # Encode Inputs
        rgb, real_hyper = self.encode_input(Variable(rgb), Variable(hyper), infer=True)

        # Fake Generation
        input_concat = rgb
        with torch.no_grad():
            fake_hyper = self.netG.forward(input_concat)
        return fake_hyper

    def sample_features(self, inst):
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        features_clustered = np.load(cluster_path, encoding='latin1').item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
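# Worked example of update_learning_rate above: each call subtracts a fixed
# lrd = opt.lr / opt.niter_decay, so the rate reaches zero after niter_decay
# calls. The values below are hypothetical.
lr = 2e-4
niter_decay = 100
lrd = lr / niter_decay           # 2e-6 subtracted per epoch
print(lr - lrd)                  # 0.000198 after the first decay epoch
print(lr - niter_decay * lrd)    # 0.0 once the decay schedule finishes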
class train_style_translator_T(base_model):
    def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
        super(train_style_translator_T, self).__init__(args)
        self._initialize_training()

        self.dataloaders_single = dataloaders_single
        self.dataloaders_xLabels_joint = dataloaders_xLabels_joint

        # define loss weights
        self.lambda_identity = 0.5  # coefficient of identity mapping score
        self.lambda_real = 10.0
        self.lambda_synthetic = 10.0
        self.lambda_GAN = 1.0

        # define pool size in adversarial loss
        self.pool_size = 50
        self.generated_syn_pool = ImagePool(self.pool_size)
        self.generated_real_pool = ImagePool(self.pool_size)

        self.netD_s = Discriminator80x80InstNorm(input_nc=3)
        self.netD_r = Discriminator80x80InstNorm(input_nc=3)
        self.netG_s2r = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.netG_r2s = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s']
        self.L1loss = nn.L1Loss()

        if self.isTrain:
            self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) + list(self.netD_r.parameters()),
                                             lr=self.D_lr, betas=(0.5, 0.999))
            self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) + list(self.netG_s2r.parameters()),
                                             lr=self.G_lr, betas=(0.5, 0.999))
            self.optim_name = ['netD_optimizer', 'netG_optimizer']
            self._get_scheduler()
            self.loss_BCE = nn.BCEWithLogitsLoss()
            self._initialize_networks()

            # apex can only be applied to CUDA models
            if self.use_apex:
                self._init_apex(Num_losses=3)

        self._check_parallel()

    def _get_project_name(self):
        return 'train_style_translator_T'

    def _initialize_networks(self):
        for name in self.model_name:
            getattr(self, name).train().to(self.device)
            init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)

    def compute_D_loss(self, real_sample, fake_sample, netD):
        loss = 0
        syn_acc = 0
        real_acc = 0

        output = netD(fake_sample)
        label = torch.full((output.size()), self.syn_label, device=self.device)
        predSyn = (output > 0.5).to(self.device, dtype=torch.float32)
        total_num = torch.numel(output)
        syn_acc += (predSyn == label).type(torch.float32).sum().item() / total_num
        loss += self.loss_BCE(output, label)

        output = netD(real_sample)
        label = torch.full((output.size()), self.real_label, device=self.device)
        predReal = (output > 0.5).to(self.device, dtype=torch.float32)
        real_acc += (predReal == label).type(torch.float32).sum().item() / total_num
        loss += self.loss_BCE(output, label)

        return loss, syn_acc, real_acc

    def compute_G_loss(self, real_sample, synthetic_sample, r2s_rgb, s2r_rgb, reconstruct_real, reconstruct_syn):
        '''
        real_sample: [batch_size, 4, 240, 320] real rgb
        synthetic_sample: [batch_size, 4, 240, 320] synthetic rgb
        r2s_rgb: netG_r2s(real)
        s2r_rgb: netG_s2r(synthetic)
        '''
        loss = 0

        # identity loss if applicable
        if self.lambda_identity > 0:
            idt_real = self.netG_s2r(real_sample)[-1]
            idt_synthetic = self.netG_r2s(synthetic_sample)[-1]
            idt_loss = (self.L1loss(idt_real, real_sample) * self.lambda_real +
                        self.L1loss(idt_synthetic, synthetic_sample) * self.lambda_synthetic) * self.lambda_identity
        else:
            idt_loss = 0

        # GAN loss
        real_pred = self.netD_r(s2r_rgb)
        real_label = torch.full(real_pred.size(), self.real_label, device=self.device)
        GAN_loss_real = self.loss_BCE(real_pred, real_label)

        syn_pred = self.netD_s(r2s_rgb)
        syn_label = torch.full(syn_pred.size(), self.real_label, device=self.device)
        GAN_loss_syn = self.loss_BCE(syn_pred, syn_label)

        GAN_loss = (GAN_loss_real + GAN_loss_syn) * self.lambda_GAN

        # cycle consistency loss
        rec_real_loss = self.L1loss(reconstruct_real, real_sample) * self.lambda_real
        rec_syn_loss = self.L1loss(reconstruct_syn, synthetic_sample) * self.lambda_synthetic
        rec_loss = rec_real_loss + rec_syn_loss

        loss += (idt_loss + GAN_loss + rec_loss)
        return loss, idt_loss, GAN_loss, rec_loss

    def train(self):
        phase = 'train'
        since = time.time()
        best_loss = float('inf')
        tensorboardX_iter_count = 0

        for epoch in range(self.total_epoch_num):
            print('\nEpoch {}/{}'.format(epoch + 1, self.total_epoch_num))
            print('-' * 10)
            fn = open(self.train_log, 'a')
            fn.write('\nEpoch {}/{}\n'.format(epoch + 1, self.total_epoch_num))
            fn.write('--' * 5 + '\n')
            fn.close()

            iterCount = 0
            for sample_dict in self.dataloaders_xLabels_joint:
                imageListReal, depthListReal = sample_dict['real']
                imageListSyn, depthListSyn = sample_dict['syn']

                imageListSyn = imageListSyn.to(self.device)
                depthListSyn = depthListSyn.to(self.device)
                imageListReal = imageListReal.to(self.device)
                depthListReal = depthListReal.to(self.device)

                with torch.set_grad_enabled(phase == 'train'):
                    s2r_rgb = self.netG_s2r(imageListSyn)[-1]
                    reconstruct_syn = self.netG_r2s(s2r_rgb)[-1]
                    r2s_rgb = self.netG_r2s(imageListReal)[-1]
                    reconstruct_real = self.netG_s2r(r2s_rgb)[-1]

                    # ---------- update generator ----------
                    set_requires_grad([self.netD_r, self.netD_s], False)
                    netG_loss = 0.
                    self.netG_optimizer.zero_grad()
                    netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss = self.compute_G_loss(
                        imageListReal, imageListSyn, r2s_rgb, s2r_rgb, reconstruct_real, reconstruct_syn)
                    if self.use_apex:
                        with amp.scale_loss(netG_loss, self.netG_optimizer, loss_id=0) as netG_loss_scaled:
                            netG_loss_scaled.backward()
                    else:
                        netG_loss.backward()
                    self.netG_optimizer.step()

                    # ---------- update discriminator ----------
                    set_requires_grad([self.netD_r, self.netD_s], True)
                    self.netD_optimizer.zero_grad()
                    r2s_rgb_pool = self.generated_syn_pool.query(r2s_rgb)
                    netD_s_loss, netD_s_syn_acc, netD_s_real_acc = self.compute_D_loss(
                        imageListSyn, r2s_rgb.detach(), self.netD_s)
                    s2r_rgb_pool = self.generated_real_pool.query(s2r_rgb)
                    netD_r_loss, netD_r_syn_acc, netD_r_real_acc = self.compute_D_loss(
                        imageListReal, s2r_rgb.detach(), self.netD_r)
                    netD_loss = netD_s_loss + netD_r_loss
                    if self.use_apex:
                        with amp.scale_loss(netD_loss, self.netD_optimizer, loss_id=1) as netD_loss_scaled:
                            netD_loss_scaled.backward()
                    else:
                        netD_loss.backward()
                    self.netD_optimizer.step()

                iterCount += 1

                if self.use_tensorboardX:
                    self.train_display_freq = len(self.dataloaders_xLabels_joint)  # feel free to adjust the display frequency
                    nrow = imageListReal.size()[0]
                    if tensorboardX_iter_count % self.train_display_freq == 0:
                        s2r_rgb_concat = torch.cat((imageListSyn, s2r_rgb, imageListReal, reconstruct_syn), dim=0)
                        self.write_2_tensorboardX(self.train_SummaryWriter, s2r_rgb_concat,
                                                  name='RGB: syn, s2r, real, reconstruct syn',
                                                  mode='image', count=tensorboardX_iter_count, nrow=nrow)
                        r2s_rgb_concat = torch.cat((imageListReal, r2s_rgb, imageListSyn, reconstruct_real), dim=0)
                        self.write_2_tensorboardX(self.train_SummaryWriter, r2s_rgb_concat,
                                                  name='RGB: real, r2s, synthetic, reconstruct real',
                                                  mode='image', count=tensorboardX_iter_count, nrow=nrow)

                    loss_val_list = [netD_loss, netG_loss]
                    loss_name_list = ['netD_loss', 'netG_loss']
                    self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list,
                                              mode='scalar', count=tensorboardX_iter_count)
                    tensorboardX_iter_count += 1

                if iterCount % 20 == 0:
                    loss_summary = '\t{}/{} netD: {:.7f}, netG: {:.7f}'.format(
                        iterCount, len(self.dataloaders_xLabels_joint), netD_loss, netG_loss)
                    G_loss_summary = '\t\tG loss summary: netG: {:.7f}, idt_loss: {:.7f}, GAN_loss: {:.7f}, rec_loss: {:.7f}'.format(
                        netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss)
                    print(loss_summary)
                    print(G_loss_summary)
                    fn = open(self.train_log, 'a')
                    fn.write(loss_summary + '\n')
                    fn.write(G_loss_summary + '\n')
                    fn.close()

            if (epoch + 1) % self.save_steps == 0:
                self.save_models(['netG_r2s'], mode=epoch + 1, save_list=['styleTranslator'])

            # take step in optimizer
            for scheduler in self.scheduler_list:
                scheduler.step()
            for optim in self.optim_name:
                lr = getattr(self, optim).param_groups[0]['lr']
                lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(
                    epoch + 1, self.total_epoch_num, optim, lr)
                print(lr_update)
                fn = open(self.train_log, 'a')
                fn.write(lr_update + '\n')
                fn.close()

        time_elapsed = time.time() - since
        print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        fn = open(self.train_log, 'a')
        fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
        fn.close()

    def evaluate(self, mode):
        pass
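# The training loop above toggles discriminator gradients with
# set_requires_grad(...) before the G and D updates. A minimal sketch of that
# helper, consistent with the call sites above (the repo's own version may
# differ): it freezes or unfreezes parameters so that the generator step does
# not also accumulate gradients in the discriminators.
def set_requires_grad_sketch(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad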
class ITN():
    def __repr__(self):
        return ('{name})'.format(name=self.__class__.__name__, **self.__dict__))

    def initialize(self, opt, log):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        nb = opt.cycle_batchSize
        crop_height, crop_width = opt.crop_height, opt.crop_width
        self.input_A = self.Tensor(nb, 3, crop_height, crop_width)
        self.input_B = self.Tensor(nb, 3, crop_height, crop_width)

        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = define_G(gpu_ids=self.gpu_ids)
        self.netG_B = define_G(gpu_ids=self.gpu_ids)
        self.netD_A = define_D(gpu_ids=self.gpu_ids)
        self.netD_B = define_D(gpu_ids=self.gpu_ids)

        # for training
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = GANLoss(use_lsgan=True, tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.cycle_lr, betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.cycle_lr, betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.cycle_lr, betas=(opt.cycle_beta1, 0.999))
        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D_A)
        self.optimizers.append(self.optimizer_D_B)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, opt))

        utils.print_log('------------ Networks initialized -------------', log)
        print_network(self.netG_A, 'netG_A', log)
        print_network(self.netG_B, 'netG_B', log)
        print_network(self.netD_A, 'netD_A', log)
        print_network(self.netD_B, 'netD_B', log)
        utils.print_log('-----------------------------------------------', log)

    def set_mode(self, mode):
        if mode.lower() == 'train':
            self.netG_A.train()
            self.netG_B.train()
            self.netD_A.train()
            self.netD_B.train()
            self.criterionGAN.train()
            self.criterionCycle.train()
            self.criterionIdt.train()
        elif mode.lower() == 'eval':
            self.netG_A.eval()
            self.netG_B.eval()
            self.netD_A.eval()
            self.netD_B.eval()
        else:
            raise NameError('Invalid mode: {}'.format(mode))

    def set_input(self, input):
        input_A = input['A']
        input_B = input['B']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)

    def prepare_input(self):
        self.real_A = torch.autograd.Variable(self.input_A)
        self.real_B = torch.autograd.Variable(self.input_B)

    def num_parameters(self):
        params = count_parameters_in_MB(self.netG_A)
        params += count_parameters_in_MB(self.netG_B)
        params += count_parameters_in_MB(self.netD_A)
        params += count_parameters_in_MB(self.netD_B)
        return params

    def num_flops(self):
        self.prepare_input()
        flops1, params1 = get_model_infos(self.netG_A.model, None, self.real_A)
        fake_B = self.netG_A(self.real_A)
        flops2, params2 = get_model_infos(self.netD_A.model, None, fake_B)
        return flops1 + flops2

    def test(self):
        self.real_A = torch.autograd.Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = torch.autograd.Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + \
            self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.prepare_input()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        D_A = self.loss_D_A.item()
        G_A = self.loss_G_A.item()
        Cyc_A = self.loss_cycle_A.item()
        D_B = self.loss_D_B.item()
        G_B = self.loss_G_B.item()
        Cyc_B = self.loss_cycle_B.item()
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.item()
            idt_B = self.loss_idt_B.item()
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self, isTrain):
        real_A = tensor2im(self.real_A.data)
        rec_A = tensor2im(self.rec_A.data)
        fake_A = tensor2im(self.fake_A.data)
        real_B = tensor2im(self.real_B.data)
        rec_B = tensor2im(self.rec_B.data)
        fake_B = tensor2im(self.fake_B.data)
        if isTrain and self.opt.identity > 0.0:
            idt_A = tensor2im(self.idt_A.data)
            idt_B = tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, save_dir, log):
        save_network(save_dir, 'G_A', self.netG_A, self.gpu_ids)
        save_network(save_dir, 'D_A', self.netD_A, self.gpu_ids)
        save_network(save_dir, 'G_B', self.netG_B, self.gpu_ids)
        save_network(save_dir, 'D_B', self.netD_B, self.gpu_ids)
        utils.print_log('save the model into {}'.format(save_dir), log)

    def load(self, save_dir, log):
        load_network(save_dir, 'G_A', self.netG_A)
        load_network(save_dir, 'D_A', self.netD_A)
        load_network(save_dir, 'G_B', self.netG_B)
        load_network(save_dir, 'D_B', self.netD_B)
        utils.print_log('load the model from {}'.format(save_dir), log)

    # update learning rate (called once every epoch)
    def update_learning_rate(self, log):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        utils.print_log('learning rate = {:.7f}'.format(lr), log)
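# Several snippets above build per-optimizer schedulers via
# get_scheduler(optimizer, opt). A sketch of the linear-decay variant these
# CycleGAN-style codebases typically use: constant rate for opt.niter epochs,
# then linear decay to zero over opt.niter_decay epochs. The opt field names
# (epoch_count, niter, niter_decay) are assumptions here.
from torch.optim import lr_scheduler

def get_scheduler_sketch(optimizer, opt):
    def lambda_rule(epoch):
        # multiplicative factor applied to the base learning rate
        return 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)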
class pix2pixGAN(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options():
        parser = two_domain_parser_options()
        return add_lambda_L1(parser)

    def __init__(self, args, logger):
        super().__init__(args, logger)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['loss_G', 'loss_D']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        self.model_names = ['G', 'D']
        self.sample_names = ['fake_B', 'real_A', 'real_B']

        # load/define networks
        self.G = networks.define_G(args.input_nc, args.output_nc, args.ngf,
                                   args.which_model_netG, args.norm, not args.no_dropout,
                                   args.init_type, args.init_gain, self.gpu_ids)
        if 'continue_train' not in args:
            use_sigmoid = args.no_lsgan
            self.D = networks.define_D(args.input_nc + args.output_nc, args.ndf,
                                       args.which_model_netD, args.n_layers_D, args.norm,
                                       use_sigmoid, args.init_type, args.init_gain, self.gpu_ids)
            self.fake_AB_pool = ImagePool(args.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not args.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=args.g_lr, betas=(args.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=args.d_lr, betas=(args.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input, args):
        AtoB = self.args.which_direction == 'AtoB'
        self.real_A = input[args.A_label if AtoB else args.B_label].to(self.device)
        self.real_B = input[args.B_label if AtoB else args.A_label].to(self.device)

    def forward(self):
        self.fake_B = self.G(self.real_A)

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.D(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.D(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.D(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.args.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self, num_steps, overwrite_gen):
        self.forward()
        # update D
        self.set_requires_grad(self.D, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        # update G
        self.set_requires_grad(self.D, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
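# Hypothetical driver loop for pix2pixGAN above; `args`, `logger`, and
# `loader` are stand-ins for whatever the surrounding training script
# provides, and are not defined in this snippet.
model = pix2pixGAN(args, logger)
for step, batch in enumerate(loader):
    model.set_input(batch, args)
    model.optimize_parameters(step, False)  # one D update, then one G update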
def initialize(self, opt):
    BaseModel.initialize(self, opt)
    nb = opt.batchSize
    size = opt.fineSize
    self.input_A = self.Tensor(nb, opt.input_nc, size, size)
    self.input_B = self.Tensor(nb, opt.output_nc, size, size)

    # load/define networks
    # The naming convention differs from the one used in the paper
    # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm, not opt.no_dropout,
                                    opt.init_type, self.gpu_ids, opt=opt)
    self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                    opt.which_model_netG, opt.norm, not opt.no_dropout,
                                    opt.init_type, self.gpu_ids, opt=opt)

    if self.isTrain:
        use_sigmoid = opt.no_lsgan and not opt.no_sigmoid
        self.netD_1A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                         opt.n_layers_D, opt.norm, use_sigmoid,
                                         opt.init_type, self.gpu_ids, one_out=True, opt=opt)
        self.netD_1B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                         opt.n_layers_D, opt.norm, use_sigmoid,
                                         opt.init_type, self.gpu_ids, one_out=True, opt=opt)
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                        opt.n_layers_D, opt.norm, use_sigmoid,
                                        opt.init_type, self.gpu_ids, one_out=False, opt=opt)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                        opt.n_layers_D, opt.norm, use_sigmoid,
                                        opt.init_type, self.gpu_ids, one_out=False, opt=opt)

    if not self.isTrain or opt.continue_train:
        which_epoch = opt.which_epoch
        self.load_network(self.netG_A, 'G_A', which_epoch)
        self.load_network(self.netG_B, 'G_B', which_epoch)
        if self.isTrain:
            self.load_network(self.netD_1A, 'D_1A', which_epoch)
            self.load_network(self.netD_1B, 'D_1B', which_epoch)
            self.load_network(self.netD_A, 'D_A', which_epoch)
            self.load_network(self.netD_B, 'D_B', which_epoch)

    if self.isTrain:
        self.old_lr = opt.lr
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_1A = torch.optim.Adam(self.netD_1A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_1B = torch.optim.Adam(self.netD_1B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D_A)
        self.optimizers.append(self.optimizer_D_B)
        self.optimizers.append(self.optimizer_D_1A)
        self.optimizers.append(self.optimizer_D_1B)
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, opt))

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG_A, opt)
    if self.isTrain:
        networks.print_network(self.netD_A, opt)
        networks.print_network(self.netD_1A, opt)
    print('-----------------------------------------------')
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG inception_9blocks' InceptionNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        assert is_train
        parser = super(CycleGANModel, CycleGANModel).modify_commandline_options(parser, is_train)
        parser.add_argument('--restore_G_A_path', type=str, default=None,
                            help='the path to restore the generator G_A')
        parser.add_argument('--restore_D_A_path', type=str, default=None,
                            help='the path to restore the discriminator D_A')
        parser.add_argument('--restore_G_B_path', type=str, default=None,
                            help='the path to restore the generator G_B')
        parser.add_argument('--restore_D_B_path', type=str, default=None,
                            help='the path to restore the discriminator D_B')
        parser.add_argument('--lambda_A', type=float, default=10.0,
                            help='weight for cycle loss (A -> B -> A)')
        parser.add_argument('--lambda_B', type=float, default=10.0,
                            help='weight for cycle loss (B -> A -> B)')
        parser.add_argument('--lambda_identity', type=float, default=0.5,
                            help='use identity mapping. '
                                 'Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. '
                                 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
        parser.add_argument('--real_stat_A_path', type=str, required=True,
                            help='the path to load the ground-truth A images information to compute FID.')
        parser.add_argument('--real_stat_B_path', type=str, required=True,
                            help='the path to load the ground-truth B images information to compute FID.')
        parser.set_defaults(norm='instance', dataset_mode='unaligned',
                            batch_size=1, ndf=64, gan_mode='lsgan',
                            nepochs=100, nepochs_decay=100, save_epoch_freq=20)
        return parser

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        BaseModel.__init__(self, opt)
        self.loss_names = ['D_A', 'G_A', 'G_cycle_A', 'G_idt_A',
                           'D_B', 'G_B', 'G_cycle_B', 'G_idt_B']
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')
        self.visual_names = visual_names_A + visual_names_B
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                        opt.norm, opt.dropout_rate, opt.init_type,
                                        opt.init_gain, self.gpu_ids, opt=opt)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG,
                                        opt.norm, opt.dropout_rate, opt.init_type,
                                        opt.init_gain, self.gpu_ids, opt=opt)
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D,
                                        opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D,
                                        opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        if opt.lambda_identity > 0.0:
            assert (opt.input_nc == opt.output_nc)
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA')

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best_A = False
        self.is_best_B = False
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)

    def set_single_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network) -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.loss_G_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_G_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_G_idt_A = 0
            self.loss_G_idt_B = 0

        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        self.loss_G_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_G_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        self.loss_G = self.loss_G_A + self.loss_G_B + \
            self.loss_G_cycle_A + self.loss_G_cycle_B + \
            self.loss_G_idt_A + self.loss_G_idt_B
        self.loss_G.backward()

    def optimize_parameters(self, steps):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        self.forward()
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

    def test_single_side(self, direction):
        generator = getattr(self, 'netG_%s' % direction[0])
        with torch.no_grad():
            self.fake_B = generator(self.real_A)

    def evaluate_model(self, step, save_image=False):
        ret = {}
        self.is_best_A = False
        self.is_best_B = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.netG_A.eval()
        self.netG_B.eval()
        for direction in ['AtoB', 'BtoA']:
            eval_dataloader = getattr(self, 'eval_dataloader_' + direction)
            fakes, names = [], []
            cnt = 0
            for i, data_i in enumerate(tqdm(eval_dataloader)):
                self.set_single_input(data_i)
                self.test_single_side(direction)
                fakes.append(self.fake_B.cpu())
                for j in range(len(self.image_paths)):
                    short_path = ntpath.basename(self.image_paths[j])
                    name = os.path.splitext(short_path)[0]
                    names.append(name)
                    if cnt < 10 or save_image:
                        input_im = util.tensor2im(self.real_A[j])
                        fake_im = util.tensor2im(self.fake_B[j])
                        util.save_image(input_im,
                                        os.path.join(save_dir, direction, 'input', '%s.png' % name),
                                        create_dir=True)
                        util.save_image(fake_im,
                                        os.path.join(save_dir, direction, 'fake', '%s.png' % name),
                                        create_dir=True)
                    cnt += 1
            suffix = direction[-1]
            fid = get_fid(fakes, self.inception_model,
                          getattr(self, 'npz_%s' % direction[-1]),
                          device=self.device,
                          batch_size=self.opt.eval_batch_size)
            if fid < getattr(self, 'best_fid_%s' % suffix):
                setattr(self, 'is_best_%s' % direction[0], True)
                setattr(self, 'best_fid_%s' % suffix, fid)
            fids = getattr(self, 'fids_%s' % suffix)
            fids.append(fid)
            if len(fids) > 3:
                fids.pop(0)
            ret['metric/fid_%s' % suffix] = fid
            ret['metric/fid_%s-mean' % suffix] = sum(getattr(self, 'fids_%s' % suffix)) / len(getattr(self, 'fids_%s' % suffix))
            ret['metric/fid_%s-best' % suffix] = getattr(self, 'best_fid_%s' % suffix)
        self.netG_A.train()
        self.netG_B.train()
        return ret
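# The identity terms in backward_G above are weighted lambda_{A,B} * lambda_idt.
# With the defaults declared in modify_commandline_options (lambda_A = lambda_B
# = 10.0, lambda_identity = 0.5) this gives half the cycle-consistency weight,
# matching the --lambda_identity help text:
lambda_A = lambda_B = 10.0
lambda_idt = 0.5
print(lambda_B * lambda_idt)  # 5.0 -> identity loss at half the cycle-loss weight (10.0)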
def __init__(self, opt):
    # Initialize the Models

    # Global Variables
    self.opt = opt
    self.gpu_ids = opt.gpu_ids
    self.isTrain = opt.isTrain
    self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    self.device = torch.device(f'cuda:{self.gpu_ids[0]}') if self.gpu_ids else torch.device('cpu')
    self.metric = 0  # used for learning rate policy 'plateau'

    self.G_AtoB = build_G(input_nc=opt.input_nc, output_nc=opt.output_nc, ngf=opt.ngf,
                          norm=opt.norm, padding_type=opt.padding_type,
                          use_dropout=not opt.no_dropout, n_blocks=opt.n_blocks_G,
                          init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=opt.gpu_ids)
    self.G_BtoA = build_G(input_nc=opt.output_nc, output_nc=opt.input_nc, ngf=opt.ngf,
                          norm=opt.norm, padding_type=opt.padding_type,
                          use_dropout=not opt.no_dropout, n_blocks=opt.n_blocks_G,
                          init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=opt.gpu_ids)
    self.net_names = ['G_AtoB', 'G_BtoA']

    if self.isTrain:
        self.D_A = build_D(input_nc=opt.output_nc, ndf=opt.ndf, n_layers=opt.n_layers_D,
                           norm=opt.norm, init_type=opt.init_type,
                           init_gain=opt.init_gain, gpu_ids=opt.gpu_ids)
        self.D_B = build_D(input_nc=opt.input_nc, ndf=opt.ndf, n_layers=opt.n_layers_D,
                           norm=opt.norm, init_type=opt.init_type,
                           init_gain=opt.init_gain, gpu_ids=opt.gpu_ids)
        self.net_names.append('D_A')
        self.net_names.append('D_B')

        # create image buffer to store previously generated images
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizers = []
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.G_AtoB.parameters(), self.G_BtoA.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.D_A.parameters(), self.D_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        # lr Scheduler
        self.schedulers = [get_scheduler(optimizer,
                                         lr_policy=opt.lr_policy,
                                         n_epochs=opt.n_epochs,
                                         lr_decay_iters=opt.lr_decay_iters,
                                         epoch_count=opt.epoch_count,
                                         n_epochs_decay=opt.n_epochs_decay)
                           for optimizer in self.optimizers]

    # Internal Variables
    self.real_A = None
    self.real_B = None
    self.image_paths = None
    self.fake_A = None
    self.fake_B = None
    self.rec_A = None
    self.rec_B = None
    self.idt_A = None
    self.idt_B = None
    self.loss_idt_A = None
    self.loss_idt_B = None
    self.loss_G_AtoB = None
    self.loss_G_BtoA = None
    self.cycle_loss_A = None
    self.cycle_loss_B = None
    self.loss_G = None
    self.loss_D_A = None
    self.loss_D_B = None

    # Printing the Networks
    for net_name in self.net_names:
        print(net_name, "\n", getattr(self, net_name))

    # Continue training, if isTrain
    if self.isTrain:
        if self.opt.ct > 0:
            print(f"Continue training from {self.opt.ct}")
            self.load_train_model(str(self.opt.ct))
class VanillaGanSingleArchitecture(BaseArchitecture): def __init__(self, args): super().__init__(args) if args.mode == 'train': self.D = define_D(args) self.D = self.D.to(self.device) self.fake_right_pool = ImagePool(50) self.criterion = define_generator_loss(args) self.criterion = self.criterion.to(self.device) self.criterionGAN = define_discriminator_loss(args) self.criterionGAN = self.criterionGAN.to(self.device) self.optimizer_G = optim.Adam(self.G.parameters(), lr=args.learning_rate) self.optimizer_D = optim.SGD(self.D.parameters(), lr=args.learning_rate) # Load the correct networks, depending on which mode we are in. if args.mode == 'train': self.model_names = ['G', 'D'] self.optimizer_names = ['G', 'D'] else: self.model_names = ['G'] self.loss_names = ['G', 'G_MonoDepth', 'G_GAN', 'D'] self.losses = {} if self.args.resume: self.load_checkpoint() if 'cuda' in self.device: torch.cuda.synchronize() def set_input(self, data): self.data = to_device(data, self.device) self.left = self.data['left_image'] self.right = self.data['right_image'] def forward(self): self.disps = self.G(self.left) # Prepare disparities disp_right_est = [d[:, 1, :, :].unsqueeze(1) for d in self.disps] self.disp_right_est = disp_right_est[0] self.fake_right = self.criterion.generate_image_right( self.left, self.disp_right_est) def backward_D(self): # Fake fake_pool = self.fake_right_pool.query(self.fake_right) pred_fake = self.D(fake_pool.detach()) self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real pred_real = self.D(self.right) self.loss_D_real = self.criterionGAN(pred_real, True) self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self): # G should fake D pred_fake = self.D(self.fake_right) self.loss_G_GAN = self.criterionGAN(pred_fake, True) self.loss_G_MonoDepth = self.criterion(self.disps, [self.left, self.right]) self.loss_G = self.loss_G_GAN * self.args.discriminator_w + self.loss_G_MonoDepth self.loss_G.backward() def optimize_parameters(self): self.forward() # Update D. self.set_requires_grad(self.D, True) self.optimizer_D.zero_grad() self.backward_D() self.optimizer_D.step() # Update G. self.set_requires_grad(self.D, False) self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() def update_learning_rate(self, epoch, learning_rate): """ Sets the learning rate to the initial LR decayed by 2 every 10 epochs after 30 epochs. """ if self.args.adjust_lr: if 30 <= epoch < 40: lr = learning_rate / 2 elif epoch >= 40: lr = learning_rate / 4 else: lr = learning_rate for param_group in self.optimizer_G.param_groups: param_group['lr'] = lr for param_group in self.optimizer_D.param_groups: param_group['lr'] = lr def get_untrained_loss(self): # -- Generator loss_G_MonoDepth = self.criterion(self.disps, [self.left, self.right]) fake_G_right = self.D(self.fake_right) loss_G_GAN = self.criterionGAN(fake_G_right, True) loss_G = loss_G_GAN * self.args.discriminator_w + loss_G_MonoDepth # -- Discriminator loss_D_fake = self.criterionGAN(self.D(self.fake_right), False) loss_D_real = self.criterionGAN(self.D(self.right), True) loss_D = (loss_D_fake + loss_D_real) * 0.5 return { 'G': loss_G.item(), 'G_MonoDepth': loss_G_MonoDepth.item(), 'G_GAN': loss_G_GAN.item(), 'D': loss_D.item() } @property def architecture(self): return 'Single GAN Architecture'
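# ---------------------------------------------------------------------------
# Side note (illustrative, not from the original source): the manual schedule
# in update_learning_rate above, lr/2 for epochs 30-39 and lr/4 from epoch 40,
# matches torch.optim.lr_scheduler.MultiStepLR with milestones [30, 40] and
# gamma 0.5, assuming the scheduler is stepped once per epoch.
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[30, 40], gamma=0.5)
for epoch in range(50):
    opt.step()    # the epoch's training steps would go here
    sched.step()  # lr: 1e-4 -> 5e-5 after epoch 30 -> 2.5e-5 after epoch 40
# ---------------------------------------------------------------------------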
class CrossModel(BaseModel): def __init__(self): super(CrossModel, self).__init__() self.model_names = 'cross_model' @staticmethod def modify_commandline_options(parser, is_train=True): if is_train: parser.add_argument('--style_dropout', type=float, default=.5, help='dropout ratio of style feature vector') parser.add_argument('--style_channels', type=int, default=32, help='size of style channels') parser.add_argument( '--pool_size', type=int, default=150, help= 'size of image pool, which is used to prevent model collapse') parser.add_argument('--lambda_E', type=float, default=0.0, help='lambda of extra loss') parser.add_argument('--fast_forward', type=bool, default=False, help='do not train the selector') parser.add_argument('--opt_betas1', type=float, default=.5) parser.add_argument('--opt_betas2', type=float, default=.999) parser.add_argument('--g_model_transnet', type=str, default='resnet') parser.add_argument('--g_model_transnet_n_blocks', type=int, default=8) parser.add_argument('--d_model_n_blocks', type=int, default=1) parser.add_argument('--d_model_use_dropout', type=bool, default=False) parser.add_argument('--selector_criterion_method', type=str, default='l1') return parser def init_vistool(self, opt): self.vistool = vistool.VisTool(env=opt.name + '_model') self.vistool.register_data('fake_imgs', 'images') self.vistool.register_data('styles', 'images') self.vistool.register_data('texts', 'images') self.vistool.register_data('diff_with_average', 'images') self.vistool.register_data('gmodel_sorted', 'images') self.vistool.register_data('dmodel_sorted', 'images') self.vistool.register_data('scores', 'array') self.vistool.register_data('dis_preds_L1_loss', 'scalar_ma') self.vistool.register_data('sel_preds_L1_loss', 'scalar_ma') self.vistool.register_data('rad_preds_L1_loss', 'scalar_ma') self.vistool.register_data('mod_preds_L1_loss', 'scalar_ma') self.vistool.register_window('dmodel_sorted', 'images', source='dmodel_sorted') self.vistool.register_window('gmodel_sorted', 'images', source='gmodel_sorted') if not opt.fast_forward: self.vistool.register_window('scores', 'bar', source='scores') self.vistool.register_window('preds_L1_loss', 'lines', sources=[ 'dis_preds_L1_loss', 'sel_preds_L1_loss', 'rad_preds_L1_loss', 'mod_preds_L1_loss' ]) def initialize(self, opt): super(CrossModel, self).initialize(opt) self.fastForward = opt.fast_forward self.netG = GModel() self.netD = DModel() self.netG.initialize(opt) self.netD.initialize(opt) self.criterionGAN = GANLoss(False) self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.learn_rate, betas=(opt.opt_betas1, opt.opt_betas2)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.learn_rate, betas=(opt.opt_betas1, opt.opt_betas2)) self.pool = ImagePool(opt.pool_size) self.lambda_E = opt.lambda_E self.criterionSelector = find_criterion_using_name( opt.selector_criterion_method)() init_net(self) path = opt.checkpoints_dir + '/' + self.model_names + '.txt' with open(path, 'w') as f: f.write(str(self)) logger.info("Model Structure has been written into %s" % path) self.init_vistool(opt) def set_input(self, texts, styles, target): self.texts = texts self.styles = styles self.real_img = target.unsqueeze(1) def forward(self): self.netG(self.texts, self.styles) self.fake_imgs = self.netG.basic_preds def backward_D(self): fake_all = self.fake_imgs real_all = self.real_img texts = self.texts styles = self.styles #A trick to prevent mode collapse img = torch.cat((fake_all, real_all, texts, styles), 1).detach() img = 
self.pool.query(img) tot = (img.size(1) - 1) // 3 fake_all, real_all, texts, styles = torch.split( img, [tot, 1, tot, tot], 1) fake_all = fake_all.contiguous() real_all = real_all.contiguous() pred_fake = self.netD(fake_all.detach(), texts, styles) pred_real = self.netD(real_all.detach(), texts, styles) self.loss_fake = self.criterionGAN(pred_fake, False) self.loss_real = self.criterionGAN(pred_real, True) self.loss_D = (self.loss_fake + self.loss_real) * .5 self.loss_D.backward() def backward_G(self): fake_all = self.fake_imgs pred_fake = self.netD(fake_all, self.texts, self.styles) self.loss_G = self.criterionGAN(pred_fake, True) #Gan loss self.loss_GSE = self.loss_G if not self.fastForward: pred_result = pred_fake.detach() self.loss_S = (pred_result - self.netG.basic_score).abs().mean() #Selector loss self.loss_GSE += self.loss_S self.vistool.update( 'scores', torch.stack((pred_result[0], self.netG.basic_score[0]), 1)) self.loss_E = self.netG.extra_loss # Extra loss self.loss_GSE += self.loss_E * self.lambda_E self.loss_GSE.backward() def optimize_parameters(self): self.forward() self.set_requires_grad(self.netD, True) self.optimizer_D.zero_grad() self.backward_D() if self.optm_d: self.optimizer_D.step() self.set_requires_grad(self.netD, False) self.forward() self.optimizer_G.zero_grad() self.backward_G() if self.optm_g: self.optimizer_G.step() bs, tot, W, H = self.texts.shape score = self.netG.basic_score + self.netD.basic_score * .5 rank = torch.sort(score, 1, descending=True)[1] model_preds = torch.gather( self.netG.basic_preds, 1, rank.view(bs, tot, 1, 1).expand(bs, tot, W, H)) self.vistool.update('gmodel_sorted', self.netG.best_preds[0] * .5 + .5) self.vistool.update('dmodel_sorted', self.netD.dis_preds[0] * .5 + .5) self.vistool.update('diff_with_average', self.netG.diff_with_average) self.vistool.update( 'mod_preds_L1_loss', self.criterionSelector(model_preds[:, 0, :, :], self.real_img[:, 0, :, :]).mean()) self.vistool.update( 'dis_preds_L1_loss', self.criterionSelector(self.netD.dis_preds[:, 0, :, :], self.real_img[:, 0, :, :]).mean()) self.vistool.update( 'sel_preds_L1_loss', self.criterionSelector(self.netG.best_preds[:, 0, :, :], self.real_img[:, 0, :, :]).mean()) idx = random.randint(0, self.netG.best_preds.size(1) - 1) self.vistool.update( 'rad_preds_L1_loss', self.criterionSelector(self.netG.best_preds[:, idx, :, :], self.real_img[:, 0, :, :]).mean()) self.vistool.update('fake_imgs', self.fake_imgs[0] * .5 + .5) self.vistool.update('styles', self.styles[0] * .5 + .5) self.vistool.update('texts', self.texts[0] * .5 + .5) self.vistool.sync()
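# ---------------------------------------------------------------------------
# Illustrative shape check (hypothetical sizes, not from the original source):
# the pool trick in backward_D above concatenates the fake candidates, the
# single real image, the texts and the styles along dim 1 so that ImagePool
# shuffles them as one unit; tot = (channels - 1) // 3 then recovers the
# original channel counts after the query.
import torch

tot = 4  # hypothetical number of generated candidates
fake = torch.randn(2, tot, 64, 64)
real = torch.randn(2, 1, 64, 64)
texts = torch.randn(2, tot, 64, 64)
styles = torch.randn(2, tot, 64, 64)
img = torch.cat((fake, real, texts, styles), 1)  # (2, 3 * tot + 1, 64, 64)
t = (img.size(1) - 1) // 3                       # == tot
f, r, tx, st = torch.split(img, [t, 1, t, t], 1)
assert f.shape == fake.shape and r.shape == real.shape
# ---------------------------------------------------------------------------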
def initialize(self, opt): BaseModel.initialize(self, opt) self.isTrain = opt.isTrain self.use_features = opt.instance_feat or opt.label_feat self.gen_features = self.use_features and not self.opt.load_features input_nc = opt.input_nc self.para = opt.trade_off # define networks # Generator network netG_input_nc = input_nc self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids) # Discriminator network if self.isTrain: use_sigmoid = opt.no_lsgan # netD_input_nc = input_nc + opt.output_nc netD_input_nc = opt.output_nc self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) # Encoder network if self.gen_features: self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) if self.opt.verbose: print('---------- Networks initialized -------------') # load networks if not self.isTrain or opt.continue_train or opt.load_pretrain: pretrained_path = '' if not self.isTrain else opt.load_pretrain self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) if self.isTrain: self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) if self.gen_features: self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) # set loss functions and optimizers if self.isTrain: if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: raise NotImplementedError( "Fake Pool Not Implemented for MultiGPU") self.fake_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionFeat = torch.nn.L1Loss() # AWAN self.criterionCSS = networks.CSS() # Names so we can breakout loss self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_CSS', 'D_real', 'D_fake') # initialize optimizers # optimizer G if opt.niter_fix_global > 0: import sys if sys.version_info >= (3, 0): finetune_list = set() else: from sets import Set finetune_list = Set() params_dict = dict(self.netG.named_parameters()) params = [] for key, value in params_dict.items(): if key.startswith('model' + str(opt.n_local_enhancers)): params += [value] finetune_list.add(key.split('.')[0]) print( '------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) print('The layers that are finetuned are ', sorted(finetune_list)) else: params = list(self.netG.parameters()) if self.gen_features: params += list(self.netE.parameters()) self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) # optimizer D params = list(self.netD.parameters()) self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ assert opt.isTrain assert opt.direction == 'AtoB' assert opt.dataset_mode == 'unaligned' BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = [ 'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B', 'G_idt_B' ] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert (opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images self.fake_B_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images # define loss functions self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to( self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. 
self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB') self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA') block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] self.inception_model = InceptionV3([block_idx]) self.inception_model.to(self.device) self.inception_model.eval() if 'cityscapes' in opt.dataroot: self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False) util.load_network(self.drn_model, opt.drn_path, verbose=False) if len(opt.gpu_ids) > 0: self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids) self.drn_model.eval() self.best_fid_A, self.best_fid_B = 1e9, 1e9 self.best_mIoU = -1e9 self.fids_A, self.fids_B = [], [] self.mIoUs = [] self.is_best = False self.npz_A = np.load(opt.real_stat_A_path) self.npz_B = np.load(opt.real_stat_B_path)
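# ---------------------------------------------------------------------------
# Illustrative helper (an assumption about what get_fid computes, not the
# project's implementation): the Frechet distance between the Gaussian fitted
# to generated InceptionV3 activations and the precomputed real statistics
# loaded from the .npz files above (the 'mu'/'sigma' keys are assumptions).
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # add a small jitter to the diagonals if the product is near-singular
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

# smoke test with identical Gaussians: the distance should be ~0
mu, sigma = np.zeros(4), np.eye(4)
assert frechet_distance(mu, sigma, mu, sigma) < 1e-8
# ---------------------------------------------------------------------------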
class Pix2PixModel(BaseModel): def name(self): return 'Pix2PixModel' def initialize(self, opt): BaseModel.initialize(self, opt) self.isTrain = opt.isTrain # define tensors self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) # load/define networks self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids) if not self.isTrain or opt.continue_train: self.load_network(self.netG, 'G', opt.which_epoch) if self.isTrain: self.load_network(self.netD, 'D', opt.which_epoch) if self.isTrain: self.fake_AB_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionL1 = torch.nn.L1Loss() # initialize optimizers self.schedulers = [] self.optimizers = [] self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) for optimizer in self.optimizers: self.schedulers.append(networks.get_scheduler(optimizer, opt)) print('---------- Networks initialized -------------') networks.print_network(self.netG) if self.isTrain: networks.print_network(self.netD) print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): self.real_A = Variable(self.input_A) self.fake_B = self.netG.forward(self.real_A) self.real_B = Variable(self.input_B) # no backprop gradients def test(self): self.real_A = Variable(self.input_A, volatile=True) self.fake_B = self.netG.forward(self.real_A) self.real_B = Variable(self.input_B, volatile=True) # get image paths def get_image_paths(self): return self.image_paths def backward_D(self): # Fake # stop backprop to the generator by detaching fake_B fake_AB = self.fake_AB_pool.query( torch.cat((self.real_A, self.fake_B), 1)) self.pred_fake = self.netD.forward(fake_AB.detach()) self.loss_D_fake = self.criterionGAN(self.pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) self.pred_real = self.netD.forward(real_AB) self.loss_D_real = self.criterionGAN(self.pred_real, True) # Combined loss self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self): # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) pred_fake = self.netD.forward(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # Second, G(A) = B self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A self.loss_G = self.loss_G_GAN + self.loss_G_L1 self.loss_G.backward() def optimize_parameters(self): self.forward() self.optimizer_D.zero_grad() self.backward_D() self.optimizer_D.step() self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() def 
get_current_errors(self): return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]), ('G_L1', self.loss_G_L1.data[0]), ('D_real', self.loss_D_real.data[0]), ('D_fake', self.loss_D_fake.data[0])]) def get_current_visuals(self): real_A = util.tensor2im(self.real_A.data) fake_B = util.tensor2im(self.fake_B.data) real_B = util.tensor2im(self.real_B.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) def save(self, label): self.save_network(self.netG, 'G', label, self.gpu_ids) self.save_network(self.netD, 'D', label, self.gpu_ids)
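# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption about networks.GANLoss, not the exact
# project code): a criterion that expands a boolean target into a tensor of
# the same shape as the prediction and applies MSE (LSGAN) or BCE (vanilla
# GAN), matching call sites such as self.criterionGAN(pred_fake, True) above.
import torch
import torch.nn as nn

class GANLossSketch(nn.Module):
    def __init__(self, use_lsgan=True):
        super().__init__()
        # LSGAN pairs with a D without a sigmoid; vanilla uses logits + BCE
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        target = torch.full_like(prediction, 1.0 if target_is_real else 0.0)
        return self.loss(prediction, target)

criterion = GANLossSketch(use_lsgan=True)
print(criterion(torch.randn(2, 1, 30, 30), True))  # scalar loss tensor
# ---------------------------------------------------------------------------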
class CycleGan: def __init__(self, opt): # Initialize the Models # Global Variables self.opt = opt self.gpu_ids = opt.gpu_ids self.isTrain = opt.isTrain self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) self.device = torch.device( f'cuda:{self.gpu_ids[0]}') if self.gpu_ids else torch.device('cpu') self.metric = 0 # used for learning rate policy 'plateau' self.G_AtoB = build_G(input_nc=opt.input_nc, output_nc=opt.output_nc, ngf=opt.ngf, norm=opt.norm, padding_type=opt.padding_type, use_dropout=not opt.no_dropout, n_blocks=opt.n_blocks_G, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=opt.gpu_ids) self.G_BtoA = build_G(input_nc=opt.output_nc, output_nc=opt.input_nc, ngf=opt.ngf, norm=opt.norm, padding_type=opt.padding_type, use_dropout=not opt.no_dropout, n_blocks=opt.n_blocks_G, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=opt.gpu_ids) self.net_names = ['G_AtoB', 'G_BtoA'] if self.isTrain: self.D_A = build_D(input_nc=opt.output_nc, ndf=opt.ndf, n_layers=opt.n_layers_D, norm=opt.norm, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=opt.gpu_ids) self.D_B = build_D(input_nc=opt.input_nc, ndf=opt.ndf, n_layers=opt.n_layers_D, norm=opt.norm, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=opt.gpu_ids) self.net_names.append('D_A') self.net_names.append('D_B') # create image buffer to store previously generated images self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions self.criterionGAN = GANLoss(opt.gan_mode).to( self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizers = [] self.optimizer_G = torch.optim.Adam(itertools.chain( self.G_AtoB.parameters(), self.G_BtoA.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.D_A.parameters(), self.D_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) # lr Scheduler self.schedulers = [ get_scheduler(optimizer, lr_policy=opt.lr_policy, n_epochs=opt.n_epochs, lr_decay_iters=opt.lr_decay_iters, epoch_count=opt.epoch_count, n_epochs_decay=opt.n_epochs_decay) for optimizer in self.optimizers ] # Internal Variables self.real_A = None self.real_B = None self.image_paths = None self.fake_A = None self.fake_B = None self.rec_A = None self.rec_B = None self.idt_A = None self.idt_B = None self.loss_idt_A = None self.loss_idt_B = None self.loss_G_AtoB = None self.loss_G_BtoA = None self.cycle_loss_A = None self.cycle_loss_B = None self.loss_G = None self.loss_D_A = None self.loss_D_B = None # Printing the Networks for net_name in self.net_names: print(net_name, "\n", getattr(self, net_name)) # Continue training, if isTrain if self.isTrain: if self.opt.ct > 0: print(f"Continue training from {self.opt.ct}") self.load_train_model(str(self.opt.ct)) def update_learning_rate(self): """Update learning rates for all the networks; called at the end of every epoch""" old_lr = self.optimizers[0].param_groups[0]['lr'] for scheduler in self.schedulers: if self.opt.lr_policy == 'plateau': scheduler.step(self.metric) else: scheduler.step() lr = self.optimizers[0].param_groups[0]['lr'] print('learning rate %.7f -> %.7f' % (old_lr, lr)) def feed_input(self, x): """Unpack input data from the dataloader and perform necessary pre-processing steps. 
:type x: dict :param x: include the data itself and its metadata information. x should have the structure {'A': Tensor Images, 'B': Tensor Images, 'A_paths': paths of the A Images, 'B_paths': paths of the B Images} The option 'direction' can be used to swap domain A and domain B. """ AtoB = self.opt.direction == 'AtoB' self.real_A = x['A' if AtoB else 'B'].to(self.device) self.real_B = x['B' if AtoB else 'A'].to(self.device) self.image_paths = x['A_paths' if AtoB else 'B_paths'] def optimize_parameters(self): # Forward self.forward() # Train Generators self._set_requires_grad( [self.D_A, self.D_B], False) # Ds require no gradients when optimizing Gs self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero self.backward_G() # calculate gradients for G_A and G_B self.optimizer_G.step() # update G_A and G_B's weights # Train Discriminators self._set_requires_grad([self.D_A, self.D_B], True) self.optimizer_D.zero_grad() self.backward_D_A() self.backward_D_B() self.optimizer_D.step() def forward(self): """Run forward pass Called by both functions <optimize_parameters> and <test> """ self.fake_B = self.G_AtoB(self.real_A) # G_A(A) self.rec_A = self.G_BtoA(self.fake_B) # G_B(G_A(A)) self.fake_A = self.G_BtoA(self.real_B) # G_B(B) self.rec_B = self.G_AtoB(self.fake_A) # G_A(G_B(B)) def backward_G(self): lambda_idt = getattr(self.opt, 'lambda_identity', 0) lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss: loss_idt_A/loss_idt_B are initialized to None in __init__, so they must be computed (or zeroed) here before being added to loss_G; the scaling follows the usual CycleGAN convention if lambda_idt > 0: self.idt_A = self.G_AtoB(self.real_B) self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt self.idt_B = self.G_BtoA(self.real_A) self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss D_A(G_AtoB(A)) self.loss_G_AtoB = self.criterionGAN(self.D_A(self.fake_B), True) # GAN loss D_B(G_BtoA(B)) self.loss_G_BtoA = self.criterionGAN(self.D_B(self.fake_A), True) # Forward cycle loss || G_B(G_A(A)) - A|| self.cycle_loss_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss || G_A(G_B(B)) - B|| self.cycle_loss_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss and calculate gradients self.loss_G = self.loss_G_AtoB + self.loss_G_BtoA + self.cycle_loss_A + self.cycle_loss_B self.loss_G += self.loss_idt_A + self.loss_idt_B self.loss_G.backward() def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator :param netD: the discriminator D :param real: real images :param fake: images generated by a generator :return: Loss """ # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.D_A, self.real_B, fake_B) def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.D_B, self.real_A, fake_A) def _set_requires_grad(self, nets: List[nn.Module], requires_grad: bool = False) -> None: """ Set requires_grad for all the networks; pass False to avoid unnecessary computations :param nets: a list of networks :param requires_grad: whether the networks require gradients or not """ for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad def train(self): """Set the models to train mode""" self.G_AtoB.train() self.G_BtoA.train() if self.isTrain: self.D_A.train() self.D_B.train() def eval(self): """Set the models to eval mode during test time""" self.G_AtoB.eval() self.G_BtoA.eval() if self.isTrain: self.D_A.eval() self.D_B.eval()
def compute_visuals(self, bidirectional=False): """ Computes the Visual output data from the model :type bidirectional: bool :param bidirectional: if true, Calculate both AtoB and BtoA, else calculate AtoB """ self.eval() with torch.no_grad(): self.fake_B = self.G_AtoB(self.real_A) if bidirectional: self.fake_A = self.G_BtoA(self.real_B) def _load_objects(self, file_names: List[str], object_names: List[str]): """Load objects from file :param file_names: Name of the Files to load :param object_names: Name of the object, where the files is going to be stored. file_names and object_names should be in same order """ for file_name, object_name in zip(file_names, object_names): model_name = os.path.join(self.save_dir, file_name) print(f"Loading {object_name} from {model_name}") state_dict = torch.load(model_name, map_location=self.device) net = getattr(self, object_name) if isinstance(net, torch.nn.DataParallel): net = net.module net.load_state_dict(state_dict) def load_networks(self, initials, load_D=False): """ Loading Models Loads from /checkpoint_dir/name/{initials}_net_G_AtoB.pt :type initials: str :param initials: The initials of the model :type load_D: bool :param load_D: Is loading D or not """ file_names = [f"{initials}_net_G_AtoB.pt", f"{initials}_net_G_BtoA.pt"] if load_D: file_names.append(f"{initials}_net_D_A.pt") file_names.append(f"{initials}_net_D_B.pt") object_names = ['G_AtoB', 'G_BtoA'] if not load_D else [ 'G_AtoB', 'G_BtoA', 'D_A', 'D_B' ] self._load_objects(file_names, object_names) def load_lr_schedulers(self, initials): s_file_name_0 = os.path.join(self.save_dir, f"{initials}_scheduler_0.pt") s_file_name_1 = os.path.join(self.save_dir, f"{initials}_scheduler_1.pt") print(f"Loading scheduler-0 from {s_file_name_0}") self.schedulers[0].load_state_dict( torch.load(s_file_name_0, map_location=self.device)) print(f"Loading scheduler-1 from {s_file_name_1}") self.schedulers[1].load_state_dict( torch.load(s_file_name_1, map_location=self.device)) def load_train_model(self, initials): """ Loading Models for training purpose :type initials: str :param initials: Initials of the object names """ self.load_networks(initials, load_D=True) optim_file_names = [f"{initials}_optim_G.pt", f"{initials}_optim_D.pt"] optim_object_names = ['optimizer_G', 'optimizer_D'] self._load_objects(optim_file_names, optim_object_names) self.load_lr_schedulers(initials) def save_networks(self, epoch): """Save models :type epoch: str :param epoch: Current Epoch (prefix for the name) """ for net_name in self.net_names: net = getattr(self, net_name) self.save_network(net, net_name, epoch) def save_optimizers_and_scheduler(self, epoch): """Save optimizers :type epoch: str :param epoch: Current Epoch (prefix for the name) """ # Saving Optimizers self.save_optimizer_scheduler(self.optimizer_G, f"{epoch}_optim_G.pt") self.save_optimizer_scheduler(self.optimizer_D, f"{epoch}_optim_D.pt") # Saving Schedulers self.save_optimizer_scheduler(self.schedulers[0], f"{epoch}_scheduler_0.pt") self.save_optimizer_scheduler(self.schedulers[1], f"{epoch}_scheduler_1.pt") def save_optimizer_scheduler(self, optim_or_scheduler, name): """Save a single optimizer :param optim_or_scheduler: The optimizer object :type name: str :param name: Name of the optimizer """ save_path = os.path.join(self.save_dir, name) torch.save(optim_or_scheduler.state_dict(), save_path) def save_network(self, net, net_name, epoch): save_filename = '%s_net_%s.pt' % (epoch, net_name) if self.opt.isCloud: save_path = save_filename else: save_path = 
os.path.join(self.save_dir, save_filename) if len(self.gpu_ids) > 0 and torch.cuda.is_available(): torch.save(net.module.cpu().state_dict(), save_path) net.cuda(self.gpu_ids[0]) else: torch.save(net.cpu().state_dict(), save_path) def get_current_losses(self) -> dict: """Get the Current Losses :return: Losses """ if isinstance(self.loss_idt_A, (int, float)): idt_loss_A = self.loss_idt_A else: idt_loss_A = self.loss_idt_A.item() if isinstance(self.loss_idt_B, (int, float)): idt_loss_B = self.loss_idt_B else: idt_loss_B = self.loss_idt_B.item() return collections.OrderedDict({ 'loss_idt_A': idt_loss_A, 'loss_idt_B': idt_loss_B, 'loss_D_A': self.loss_D_A.item(), 'loss_D_B': self.loss_D_B.item(), 'loss_G_AtoB': self.loss_G_AtoB.item(), 'loss_G_BtoA': self.loss_G_BtoA.item(), 'cycle_loss_A': self.cycle_loss_A.item(), 'cycle_loss_B': self.cycle_loss_B.item() }) def get_current_image_path(self): """ :return: The current image path """ return self.image_paths def get_current_visuals(self): """Get the Current Produced Images :return: Images {real_A, real_B, fake_A, fake_B, rec_A, rec_B} :rtype: dict """ r = collections.OrderedDict({ 'real_A': self.real_A, 'real_B': self.real_B }) if self.fake_A is not None: r['fake_A'] = self.fake_A if self.fake_B is not None: r['fake_B'] = self.fake_B if self.rec_A is not None: r['rec_A'] = self.rec_A if self.rec_B is not None: r['rec_B'] = self.rec_B return r
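# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption about the 'linear' branch of the
# get_scheduler helper used above, not the project's exact code): keep the
# base lr for n_epochs, then decay it linearly to zero over the following
# n_epochs_decay epochs, implemented with LambdaLR.
import torch

def linear_decay_scheduler(optimizer, n_epochs, n_epochs_decay, epoch_count=1):
    def lambda_rule(epoch):
        # fraction of the base lr to keep at this epoch
        return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=2e-4)
sched = linear_decay_scheduler(opt, n_epochs=100, n_epochs_decay=100)
# ---------------------------------------------------------------------------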
def create_image_pools(data_pool_size): fake_A_pool = ImagePool(pool_size=data_pool_size) fake_B_pool = ImagePool(pool_size=data_pool_size) return fake_A_pool, fake_B_pool
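# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption about ImagePool, which follows the
# generated-image history buffer popularized by CycleGAN; not necessarily the
# exact class used above): query() returns the incoming image roughly half
# the time, and otherwise swaps it with a randomly chosen stored image, so
# the discriminator also sees generator outputs from earlier iterations.
import random
import torch

class ImagePoolSketch:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # pool disabled: pass images straight through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:  # fill the buffer first
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:            # swap with a stored image
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:                                  # return the new image
                out.append(image)
        return torch.cat(out, 0)

pool = ImagePoolSketch(pool_size=2)
print(pool.query(torch.randn(4, 3, 8, 8)).shape)  # torch.Size([4, 3, 8, 8])
# ---------------------------------------------------------------------------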
def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ assert opt.isTrain assert opt.direction == 'AtoB' assert opt.dataset_mode == 'unaligned' BaseModel.__init__(self, opt) self.loss_names = [ 'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B', 'G_idt_B' ] visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.opt.lambda_identity > 0.0: visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) if opt.lambda_identity > 0.0: assert (opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) self.criterionGAN = GANLoss(opt.gan_mode).to(self.device) self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB') self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA') block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] self.inception_model = InceptionV3([block_idx]) self.inception_model.to(self.device) self.inception_model.eval() self.best_fid_A, self.best_fid_B = 1e9, 1e9 self.best_mIoU = -1e9 self.fids_A, self.fids_B = [], [] self.mIoUs = [] self.is_best_A = False self.is_best_B = False self.npz_A = np.load(opt.real_stat_A_path) self.npz_B = np.load(opt.real_stat_B_path)
def train(self): """ Train the MaskShadowGAN model by starting from a saved checkpoint or from the beginning. """ if self.opt.load_model is not None: checkpoint = 'checkpoints/' + self.opt.load_model else: checkpoint_name = datetime.now().strftime("%d%m%Y-%H%M") checkpoint = 'checkpoints/{}'.format(checkpoint_name) try: os.makedirs(checkpoint) except os.error: print("Failed to make new checkpoint directory.") sys.exit(1) # build the Mask-ShadowGAN graph graph = tf.Graph() with graph.as_default(): maskshadowgan = MaskShadowGANModel(self.opt, training=True) dataA_iter, dataB_iter, realA, realB = maskshadowgan.generate_dataset( ) fakeA, fakeB, optimizers, Gen_loss, D_A_loss, D_B_loss = maskshadowgan.build( ) saver = tf.train.Saver(max_to_keep=2) summary = tf.summary.merge_all() writer = tf.summary.FileWriter(checkpoint, graph) # create image pools for holding previously generated images fakeA_pool = ImagePool(self.opt.pool_size) fakeB_pool = ImagePool(self.opt.pool_size) # create queue to hold generated shadow masks mask_queue = MaskQueue(self.opt.queue_size) with tf.Session(graph=graph) as sess: if self.opt.load_model is not None: # restore graph and variables saver.restore(sess, tf.train.latest_checkpoint(checkpoint)) ckpt = tf.train.get_checkpoint_state(checkpoint) step = int( os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) else: sess.run(tf.global_variables_initializer()) step = 0 max_steps = self.opt.niter + self.opt.niter_decay # initialize data iterators sess.run([dataA_iter.initializer, dataB_iter.initializer]) try: while step < max_steps: try: realA_img, realB_img = sess.run([realA, realB ]) # fetch inputs # generate shadow free image from shadow image fakeB_img = sess.run( fakeB, feed_dict={maskshadowgan.realA: realA_img}) # generate shadow mask and add to mask queue mask_queue.insert(mask_generator(realA_img, fakeB_img)) rand_mask = mask_queue.rand_item() # generate shadow image from shadow free image and shadow mask fakeA_img = sess.run(fakeA, feed_dict={ maskshadowgan.realB: realB_img, maskshadowgan.rand_mask: rand_mask }) # calculate losses for the generators and discriminators and minimize them _, Gen_loss_val, D_B_loss_val, \ D_A_loss_val, summary_str = sess.run([optimizers, Gen_loss, D_B_loss, D_A_loss, summary], feed_dict={maskshadowgan.realA: realA_img, maskshadowgan.realB: realB_img, maskshadowgan.rand_mask: rand_mask, maskshadowgan.last_mask: mask_queue.last_item(), maskshadowgan.fakeA: fakeA_pool.query(fakeA_img), maskshadowgan.fakeB: fakeB_pool.query(fakeB_img)}) writer.add_summary(summary_str, step) writer.flush() # display the losses of the Generators and Discriminators if step % self.opt.display_frequency == 0: print('Step {}:'.format(step)) print('Gen_loss: {}'.format(Gen_loss_val)) print('D_B_loss: {}'.format(D_B_loss_val)) print('D_A_loss: {}'.format(D_A_loss_val)) # save a checkpoint of the model to the `checkpoints` directory if step % self.opt.checkpoint_frequency == 0: save_path = saver.save(sess, checkpoint + '/model.ckpt', global_step=step) print("Model saved as {}".format(save_path)) step += 1 except tf.errors.OutOfRangeError: # reinitialize iterators after every full pass through the dataset sess.run( [dataA_iter.initializer, dataB_iter.initializer]) except KeyboardInterrupt: # save training progress before exiting print( "Saving the model's training progress to the `checkpoints` directory..." ) save_path = saver.save(sess, checkpoint + '/model.ckpt', global_step=step) print("Model saved as {}".format(save_path)) sys.exit(0)
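# ---------------------------------------------------------------------------
# Illustrative guess at the mask_generator used above (hedged: Mask-ShadowGAN
# derives a binary shadow mask from the difference between the shadow input
# and the generated shadow-free output; the simple fixed threshold here is an
# assumption, not the project's code, which may use e.g. Otsu thresholding).
import numpy as np

def mask_generator_sketch(real_a, fake_b, thresh=0.1):
    # shadow regions are darker in real_a than in the shadow-free fake_b
    diff = np.mean(fake_b.astype(np.float32) - real_a.astype(np.float32), axis=-1)
    return (diff > thresh).astype(np.float32)  # 1 inside the shadow, 0 outside

mask = mask_generator_sketch(np.zeros((8, 8, 3)), np.ones((8, 8, 3)))
assert mask.shape == (8, 8) and mask.max() == 1.0
# ---------------------------------------------------------------------------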
class CycleMultiDModel(CycleGANModel): def name(self): return 'CycleMultiDModel' def initialize(self, opt): BaseModel.initialize(self, opt) nb = opt.batchSize size = opt.fineSize self.input_A = self.Tensor(nb, opt.input_nc, size, size) self.input_B = self.Tensor(nb, opt.output_nc, size, size) # load/define networks # The naming convention is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt) if self.isTrain: use_sigmoid = opt.no_lsgan and not opt.no_sigmoid self.netD_1A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, one_out=True, opt=opt) self.netD_1B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, one_out=True, opt=opt) self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, one_out=False, opt=opt) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, one_out=False, opt=opt) if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netD_1A, 'D_1A', which_epoch) self.load_network(self.netD_1B, 'D_1B', which_epoch) self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_1A = torch.optim.Adam(self.netD_1A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_1B = torch.optim.Adam(self.netD_1B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.schedulers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D_A) self.optimizers.append(self.optimizer_D_B) self.optimizers.append(self.optimizer_D_1A) self.optimizers.append(self.optimizer_D_1B) for optimizer in self.optimizers: self.schedulers.append(networks.get_scheduler(optimizer, opt)) print('---------- Networks initialized -------------') networks.print_network(self.netG_A, opt) if self.isTrain: networks.print_network(self.netD_A, opt) networks.print_network(self.netD_1A, opt) print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A']
self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): self.real_A = Variable(self.input_A) self.real_B = Variable(self.input_B) def test(self): self.real_A = Variable(self.input_A, volatile=True) self.fake_B = self.netG_A.forward(self.real_A) self.rec_A = self.netG_B.forward(self.fake_B) self.real_B = Variable(self.input_B, volatile=True) self.fake_A = self.netG_B.forward(self.real_B) self.rec_B = self.netG_A.forward(self.fake_A) # get image paths def get_image_paths(self): return self.image_paths def backward_D_basic(self, netD, real, fake): # Real pred_real = netD.forward(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD.forward(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D_real, loss_D_fake def backward_D_A(self): fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A_real, self.loss_D_A_fake = self.backward_D_basic( self.netD_A, self.real_B, fake_B) self.loss_D_1A_real, self.loss_D_1A_fake = self.backward_D_basic( self.netD_1A, self.real_B, fake_B) def backward_D_B(self): fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B_real, self.loss_D_B_fake = self.backward_D_basic( self.netD_B, self.real_A, fake_A) self.loss_D_1B_real, self.loss_D_1B_fake = self.backward_D_basic( self.netD_1B, self.real_A, fake_A) def backward_G(self): lambda_idt = self.opt.identity lambda_rec = self.opt.lambda_rec lambda_adv = self.opt.lambda_adv # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. self.idt_A = self.netG_A.forward(self.real_B) self.loss_idt_A = self.criterionIdt( self.idt_A, self.real_B) * lambda_rec * lambda_idt # G_B should be identity if real_A is fed. 
self.idt_B = self.netG_B.forward(self.real_A) self.loss_idt_B = self.criterionIdt( self.idt_B, self.real_A) * lambda_rec * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss # D_A(G_A(A)) self.fake_B = self.netG_A.forward(self.real_A) pred_fake = self.netD_A.forward(self.fake_B) pred_1fake = self.netD_1A.forward(self.fake_B) self.loss_G_A = (self.criterionGAN(pred_fake, True) + self.criterionGAN(pred_1fake, True)) * lambda_adv # D_B(G_B(B)) self.fake_A = self.netG_B.forward(self.real_B) pred_fake = self.netD_B.forward(self.fake_A) pred_1fake = self.netD_1B.forward(self.fake_A) self.loss_G_B = (self.criterionGAN(pred_fake, True) + self.criterionGAN(pred_1fake, True)) * lambda_adv # Forward cycle loss self.rec_A = self.netG_B.forward(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_rec # Backward cycle loss self.rec_B = self.netG_A.forward(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_rec # combined loss self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G.backward() def optimize_parameters(self): # forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() # D_A and D_1A (backward_D_A computes gradients for both netD_A and netD_1A, so both optimizers are zeroed and stepped) self.optimizer_D_A.zero_grad() self.optimizer_D_1A.zero_grad() self.backward_D_A() self.optimizer_D_A.step() self.optimizer_D_1A.step() # D_B and D_1B self.optimizer_D_B.zero_grad() self.optimizer_D_1B.zero_grad() self.backward_D_B() self.optimizer_D_B.step() self.optimizer_D_1B.step() def get_current_errors(self): D_A = self.loss_D_A_real.data[0] + self.loss_D_A_fake.data[0] D_1A = self.loss_D_1A_real.data[0] + self.loss_D_1A_fake.data[0] G_A = self.loss_G_A.data[0] Cyc_A = self.loss_cycle_A.data[0] D_B = self.loss_D_B_real.data[0] + self.loss_D_B_fake.data[0] D_1B = self.loss_D_1B_real.data[0] + self.loss_D_1B_fake.data[0] G_B = self.loss_G_B.data[0] Cyc_B = self.loss_cycle_B.data[0] if self.opt.identity > 0.0: idt_A = self.loss_idt_A.data[0] idt_B = self.loss_idt_B.data[0] return OrderedDict([('D_A', D_A), ('D_1A', D_1A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A), ('D_B', D_B), ('D_1B', D_1B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)]) else: return OrderedDict([('D_A', D_A), ('D_1A', D_1A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('D_B', D_B), ('D_1B', D_1B), ('G_B', G_B), ('Cyc_B', Cyc_B)]) def get_current_lr(self): lr_A = self.optimizer_D_A.param_groups[0]['lr'] lr_B = self.optimizer_D_B.param_groups[0]['lr'] lr_G = self.optimizer_G.param_groups[0]['lr'] return OrderedDict([('D_A', lr_A), ('D_B', lr_B), ('G', lr_G)]) def get_current_visuals(self): real_A = util.tensor2im(self.real_A.data) fake_B = util.tensor2im(self.fake_B.data) rec_A = util.tensor2im(self.rec_A.data) real_B = util.tensor2im(self.real_B.data) fake_A = util.tensor2im(self.fake_A.data) rec_B = util.tensor2im(self.rec_B.data) if self.opt.isTrain and self.opt.identity > 0.0: idt_A = util.tensor2im(self.idt_A.data) idt_B = util.tensor2im(self.idt_B.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)]) else: return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) def get_network_params(self): return [('G_A', util.get_params(self.netG_A)), ('G_B', util.get_params(self.netG_B)), ('D_A', util.get_params(self.netD_A)), ('D_B', util.get_params(self.netD_B)), ('D_1A', util.get_params(self.netD_1A)), ('D_1B', util.get_params(self.netD_1B))]
def save(self, label): self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) self.save_network(self.netD_1A, 'D_1A', label, self.gpu_ids) self.save_network(self.netD_1B, 'D_1B', label, self.gpu_ids)
def do_train(Cfg, model_G, model_Dip, model_Dii, model_D_reid, train_loader, val_loader, optimizerG, optimizerDip, optimizerDii, GAN_loss, L1_loss, ReID_loss, schedulerG, schedulerDip, schedulerDii): log_period = Cfg.SOLVER.LOG_PERIOD checkpoint_period = Cfg.SOLVER.CHECKPOINT_PERIOD eval_period = Cfg.SOLVER.EVAL_PERIOD output_dir = Cfg.DATALOADER.LOG_DIR # TODO: the following constants should be moved into the cfg epsilon = 0.00001 margin = 0.4 #################################### device = "cuda" epochs = Cfg.SOLVER.MAX_EPOCHS logger = logging.getLogger('pose-transfer-gan.train') logger.info('Start training') if device: if torch.cuda.device_count() > 1: print('Using {} GPUs for training'.format( torch.cuda.device_count())) model_G = nn.DataParallel(model_G) model_Dii = nn.DataParallel(model_Dii) model_Dip = nn.DataParallel(model_Dip) model_G.to(device) model_Dip.to(device) model_Dii.to(device) model_D_reid.to(device) lossG_meter = AverageMeter() lossDip_meter = AverageMeter() lossDii_meter = AverageMeter() distDreid_meter = AverageMeter() fake_ii_pool = ImagePool(50) fake_ip_pool = ImagePool(50) # evaluator = R1_mAP(num_query, max_rank=50, feat_norm=Cfg.TEST.FEAT_NORM) # train for epoch in range(1, epochs + 1): start_time = time.time() lossG_meter.reset() lossDip_meter.reset() lossDii_meter.reset() distDreid_meter.reset() schedulerG.step() schedulerDip.step() schedulerDii.step() model_G.train() model_Dip.train() model_Dii.train() model_D_reid.eval() for iter, batch in enumerate(train_loader): img1 = batch['img1'].to(device) pose1 = batch['pose1'].to(device) img2 = batch['img2'].to(device) pose2 = batch['pose2'].to(device) input_G = (img1, pose2) # forward fake_img2 = model_G(input_G) optimizerG.zero_grad() # train G input_Dip = torch.cat((fake_img2, pose2), 1) pred_fake_ip = model_Dip(input_Dip) loss_G_ip = GAN_loss(pred_fake_ip, True) input_Dii = torch.cat((fake_img2, img1), 1) pred_fake_ii = model_Dii(input_Dii) loss_G_ii = GAN_loss(pred_fake_ii, True) loss_L1, _, _ = L1_loss(fake_img2, img2) feats_real = model_D_reid(img2) feats_fake = model_D_reid(fake_img2) dist_cos = torch.acos( torch.clamp(torch.sum(feats_real * feats_fake, 1), -1 + epsilon, 1 - epsilon)) same_id_tensor = torch.FloatTensor( dist_cos.size()).fill_(1).to(device) dist_cos_margin = torch.max(dist_cos - margin, torch.zeros_like(dist_cos)) loss_reid = ReID_loss(dist_cos_margin, same_id_tensor) factor = loss_reid_factor(epoch) loss_G = 0.5 * loss_G_ii * Cfg.LOSS.GAN_WEIGHT + 0.5 * loss_G_ip * Cfg.LOSS.GAN_WEIGHT + loss_L1 + loss_reid * Cfg.LOSS.REID_WEIGHT * factor loss_G.backward() optimizerG.step() # train Dip for i in range(Cfg.SOLVER.DG_RATIO): optimizerDip.zero_grad() real_input_ip = torch.cat((img2, pose2), 1) fake_input_ip = fake_ip_pool.query( torch.cat((fake_img2, pose2), 1).data) pred_real_ip = model_Dip(real_input_ip) loss_Dip_real = GAN_loss(pred_real_ip, True) pred_fake_ip = model_Dip(fake_input_ip) loss_Dip_fake = GAN_loss(pred_fake_ip, False) loss_Dip = 0.5 * Cfg.LOSS.GAN_WEIGHT * (loss_Dip_real + loss_Dip_fake) loss_Dip.backward() optimizerDip.step() # train Dii for i in range(Cfg.SOLVER.DG_RATIO): optimizerDii.zero_grad() real_input_ii = torch.cat((img2, img1), 1) fake_input_ii = fake_ii_pool.query( torch.cat((fake_img2, img1), 1).data) pred_real_ii = model_Dii(real_input_ii) loss_Dii_real = GAN_loss(pred_real_ii, True) pred_fake_ii = model_Dii(fake_input_ii) loss_Dii_fake = GAN_loss(pred_fake_ii, False) loss_Dii = 0.5 * Cfg.LOSS.GAN_WEIGHT * (loss_Dii_real + loss_Dii_fake) loss_Dii.backward() optimizerDii.step()
lossG_meter.update(loss_G.item(), 1) lossDip_meter.update(loss_Dip.item(), 1) lossDii_meter.update(loss_Dii.item(), 1) distDreid_meter.update(dist_cos.mean().item(), 1) if (iter + 1) % log_period == 0: logger.info( "Epoch[{}] Iteration[{}/{}] G Loss: {:.3f}, Dip Loss: {:.3f}, Dii Loss: {:.3f}, Base G_Lr: {:.2e}, Base Dip_Lr: {:.2e}, Base Dii_Lr: {:.2e}" .format(epoch, (iter + 1), len(train_loader), lossG_meter.avg, lossDip_meter.avg, lossDii_meter.avg, schedulerG.get_lr()[0], schedulerDip.get_lr()[0], schedulerDii.get_lr()[0])) #scheduler.get_lr()[0] logger.info("ReID Cos Distance: {:.3f}".format( distDreid_meter.avg)) end_time = time.time() time_per_batch = (end_time - start_time) / (iter + 1) logger.info( "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]" .format(epoch, time_per_batch, train_loader.batch_size / time_per_batch)) if epoch % checkpoint_period == 0: torch.save(model_G.state_dict(), output_dir + 'model_G_{}.pth'.format(epoch)) torch.save(model_Dip.state_dict(), output_dir + 'model_Dip_{}.pth'.format(epoch)) torch.save(model_Dii.state_dict(), output_dir + 'model_Dii_{}.pth'.format(epoch)) # if epoch % eval_period == 0: np.save(output_dir + 'train_Bx6x128x64_epoch{}.npy'.format(epoch), fake_ii_pool.images[0].cpu().numpy()) logger.info('Entering Evaluation...') tmp_results = [] model_G.eval() for iter, batch in enumerate(val_loader): with torch.no_grad(): img1 = batch['img1'].to(device) pose1 = batch['pose1'].to(device) img2 = batch['img2'].to(device) pose2 = batch['pose2'].to(device) input_G = (img1, pose2) fake_img2 = model_G(input_G) tmp_result = torch.cat((img1, img2, fake_img2), 1).cpu().numpy() tmp_results.append(tmp_result) np.save(output_dir + 'test_Bx6x128x64_epoch{}.npy'.format(epoch), tmp_results[0])
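# ---------------------------------------------------------------------------
# Illustrative sketch (restating the ReID term from do_train above as a
# standalone function; the names are ours, and the explicit L2 normalisation
# is an assumption, since the ReID model may already return unit-norm
# features): the cosine distance between ReID features of the real and
# generated images, clamped for acos stability, where distances below
# `margin` cost nothing.
import torch
import torch.nn.functional as F

def reid_margin_distance(feats_real, feats_fake, margin=0.4, epsilon=1e-5):
    feats_real = F.normalize(feats_real, dim=1)
    feats_fake = F.normalize(feats_fake, dim=1)
    cos_sim = torch.sum(feats_real * feats_fake, 1)
    dist = torch.acos(torch.clamp(cos_sim, -1 + epsilon, 1 - epsilon))
    return torch.clamp(dist - margin, min=0)  # zero penalty inside the margin

d = reid_margin_distance(torch.randn(4, 128), torch.randn(4, 128))
print(d.shape)  # torch.Size([4])
# ---------------------------------------------------------------------------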
def initialize(self, opt): BaseModel.initialize(self, opt) nb = opt.batchSize size = opt.fineSize if isinstance(opt.weight_adv, list): self.weight_adv = list(map(float, opt.weight_adv)) # list() so the weights survive repeated use under Python 3 else: self.weight_adv = None if isinstance(opt.weight_rec, list): self.weight_rec = list(map(float, opt.weight_rec)) else: self.weight_rec = None self.input_A = self.Tensor(nb, opt.input_nc, size, size) self.input_B = self.Tensor(nb, opt.output_nc, size, size) # load/define networks # The naming convention is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) if not opt.idt: self.down_2 = torch.nn.AvgPool2d(2) self.up_2 = torch.nn.Upsample(scale_factor=2, mode='bilinear') else: self.down_2 = torch.nn.AvgPool2d(1) self.up_2 = torch.nn.AvgPool2d(1) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, n_upsampling=opt.n_upsample, n_downsampling=opt.n_downsample, side='A', opt=opt) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, n_upsampling=opt.n_upsample, n_downsampling=opt.n_downsample, side='B', opt=opt) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt) print('---------- Networks initialized -------------') networks.print_network(self.netG_A, opt, input_shape=(opt.input_nc, opt.fineSize, opt.fineSize)) if self.isTrain: networks.print_network(self.netD_A, opt, input_shape=(3, opt.fineSize, opt.fineSize)) print('-----------------------------------------------') if not self.isTrain or opt.continue_train: print('Continue from', opt.which_epoch) which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain and not opt.test: self.old_lr = opt.lr self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, target_weight=self.weight_adv) self.criterionCycle = networks.RECLoss( target_weight=self.weight_rec) # initialize optimizers self.optimizer_G_A = torch.optim.Adam(self.netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_G_B = torch.optim.Adam(self.netG_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) if opt.d_lr2: self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=(opt.lr / 2.0), betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=(opt.lr / 2.0), betas=(opt.beta1, 0.999)) else: self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.schedulers = [] self.optimizers.append(self.optimizer_G_A) self.optimizers.append(self.optimizer_G_B) self.optimizers.append(self.optimizer_D_A) self.optimizers.append(self.optimizer_D_B) for optimizer in self.optimizers:
self.schedulers.append(networks.get_scheduler(optimizer, opt))
class Trainer(object): def __init__(self, cuda, model, optimizer,loss_fun, train_loader,test_loader,lmk_num,view,crossentropy_weight, out, max_epoch, network_num,batch_size,GAN, do_classification=True,do_landmarkdetect=True, size_average=False, interval_validate=None, compete = False,onlyEval=False): self.cuda = cuda self.model = model self.optim = optimizer self.train_loader = train_loader self.test_loader = test_loader self.interval_validate = interval_validate self.network_num = network_num self.do_classification = do_classification self.do_landmarkdetect = do_landmarkdetect self.crossentropy_weight = crossentropy_weight self.timestamp_start = \ datetime.datetime.now(pytz.timezone('Asia/Tokyo')) self.size_average = size_average self.out = out if not osp.exists(self.out): os.makedirs(self.out) self.lmk_num = lmk_num self.GAN = GAN self.onlyEval = onlyEval if self.GAN: GAN_lr = 0.0002 input_nc = 1 output_nc = self.lmk_num ndf = 64 norm_layer = torchsrc.models.get_norm_layer(norm_type='batch') gpu_ids = [0] self.netD = torchsrc.models.NLayerDiscriminator(input_nc+output_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=True, gpu_ids=gpu_ids) self.optimizer_D = torch.optim.Adam(self.netD.parameters(),lr=GAN_lr, betas=(0.5, 0.999)) self.netD.cuda() self.netD.apply(torchsrc.models.weights_init) pool_size = 10 self.fake_AB_pool = ImagePool(pool_size) no_lsgan = True self.Tensor = torch.cuda.FloatTensor if gpu_ids else torch.Tensor self.criterionGAN = torchsrc.models.GANLoss(use_lsgan=not no_lsgan, tensor=self.Tensor) self.max_epoch = max_epoch self.epoch = 0 self.iteration = 0 self.best_mean_iu = 0 self.compete = compete self.batch_size = batch_size self.view = view self.loss_fun = loss_fun def forward_step(self, data, category_name): if category_name == 'KidneyLong': pred_lmk = self.model(data, 'KidneyLong') elif category_name == 'KidneyTrans': pred_lmk = self.model(data, 'KidneyTrans') elif category_name == 'LiverLong': pred_lmk = self.model(data, 'LiverLong') elif category_name == 'SpleenLong': pred_lmk = self.model(data, 'SpleenLong') elif category_name == 'SpleenTrans': pred_lmk = self.model(data, 'SpleenTrans') return pred_lmk def backward_D(self,real_A,real_B,fake_B): # Fake # stop backprop to the generator by detaching fake_B fake_AB = self.fake_AB_pool.query(torch.cat((real_A, fake_B), 1)) pred_fake = self.netD.forward(fake_AB.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((real_A, real_B), 1) pred_real = self.netD.forward(real_AB) loss_D_real = self.criterionGAN(pred_real, True) # Combined loss self.loss_D = (loss_D_fake + loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self,real_A,fake_B): # First, G(A) should fake the discriminator fake_AB = torch.cat((real_A, fake_B), 1) pred_fake = self.netD.forward(fake_AB) loss_G_GAN = self.criterionGAN(pred_fake, True) return loss_G_GAN def validate(self): self.model.train() out = osp.join(self.out, 'seg_output') out_vis = osp.join(self.out, 'visualization') results_epoch_dir = osp.join(out,'epoch_%04d' % self.epoch) mkdir(results_epoch_dir) results_vis_epoch_dir = osp.join(out_vis, 'epoch_%04d' % self.epoch) mkdir(results_vis_epoch_dir) prev_sub_name = 'start' prev_view_name = 'start' for batch_idx, (data,target,target2ch,sub_name,view,img_name) in tqdm.tqdm( # enumerate(self.test_loader), total=len(self.test_loader), enumerate(self.test_loader), total=len(self.test_loader), desc='Valid epoch=%d' % self.epoch, ncols=80, leave=False): # if batch_idx>1000: # return # if self.cuda: data, 
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target, volatile=True)
            pred = self.model(data)
            lbl_pred = pred.data.max(1)[1].cpu().numpy()[:, :, :]
            batch_num = lbl_pred.shape[0]
            for si in range(batch_num):
                curr_sub_name = sub_name[si]
                curr_view_name = view[si]
                curr_img_name = img_name[si]
                if prev_sub_name == 'start':
                    # Very first slice: allocate an empty output volume.
                    if self.view == 'viewall':
                        seg = np.zeros([512, 512, 512], np.uint8)
                    else:
                        seg = np.zeros([512, 512, 1000], np.uint8)
                    slice_num = 0
                elif not (prev_sub_name == curr_sub_name and prev_view_name == curr_view_name):
                    # Subject or view changed: save the finished volume, then start a new one.
                    out_img_dir = os.path.join(results_epoch_dir, prev_sub_name)
                    mkdir(out_img_dir)
                    out_nii_file = os.path.join(out_img_dir, '%s_%s.nii.gz' % (prev_sub_name, prev_view_name))
                    seg_img = nib.Nifti1Image(seg, affine=np.eye(4))
                    nib.save(seg_img, out_nii_file)
                    if self.view == 'viewall':
                        seg = np.zeros([512, 512, 512], np.uint8)
                    else:
                        seg = np.zeros([512, 512, 1000], np.uint8)
                    slice_num = 0
                # Slices must arrive in order: the file name encodes the 1-based slice index.
                test_slice_name = 'slice_%04d.png' % (slice_num + 1)
                assert test_slice_name == curr_img_name
                seg_slice = lbl_pred[si, :, :].astype(np.uint8)
                seg_slice = scipy.misc.imresize(seg_slice, (512, 512), interp='nearest')
                # Write the slice along the axis matching its acquisition view.
                if curr_view_name == 'view1':
                    seg[slice_num, :, :] = seg_slice
                elif curr_view_name == 'view2':
                    seg[:, slice_num, :] = seg_slice
                elif curr_view_name == 'view3':
                    seg[:, :, slice_num] = seg_slice
                slice_num += 1
                prev_sub_name = curr_sub_name
                prev_view_name = curr_view_name
        # The change-of-subject check above never fires for the last volume, so save it here.
        out_img_dir = os.path.join(results_epoch_dir, curr_sub_name)
        mkdir(out_img_dir)
        out_nii_file = os.path.join(out_img_dir, '%s_%s.nii.gz' % (curr_sub_name, curr_view_name))
        seg_img = nib.Nifti1Image(seg, affine=np.eye(4))
        nib.save(seg_img, out_nii_file)

    def train(self):
        self.model.train()
        out = osp.join(self.out, 'visualization')
        mkdir(out)
        log_file = osp.join(out, 'training_loss.txt')
        fv = open(log_file, 'a')
        for batch_idx, (data, target, target2ch, sub_name, view, img_name) in tqdm.tqdm(
                enumerate(self.train_loader), total=len(self.train_loader),
                desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
            if self.cuda:
                data, target, target2ch = data.cuda(), target.cuda(), target2ch.cuda()
            data, target, target2ch = Variable(data), Variable(target), Variable(target2ch)
            pred = self.model(data)
            self.optim.zero_grad()
            if self.GAN:
                # Update D first, on (input, real label) vs. (input, prediction) pairs.
                self.optimizer_D.zero_grad()
                self.backward_D(data, target2ch, pred)
                self.optimizer_D.step()
                # Adversarial term for the segmenter, plus the supervised term below.
                loss_G_GAN = self.backward_G(data, pred)
                if self.loss_fun == 'cross_entropy':
                    arr = np.array(self.crossentropy_weight)
                    weight = torch.from_numpy(arr).cuda().float()
                    loss_G_L2 = cross_entropy2d(pred, target.long(), weight=weight)
                elif self.loss_fun == 'Dice':
                    loss_G_L2 = dice_loss(pred, target2ch)
                elif self.loss_fun == 'Dice_norm':
                    loss_G_L2 = dice_loss_norm(pred, target2ch)
                # The supervised loss is weighted 100x relative to the GAN loss.
                loss = loss_G_GAN + loss_G_L2 * 100
                fv.write('--- epoch=%d, batch_idx=%d, D_loss=%.4f, G_loss=%.4f, L2_loss=%.4f \n' % (
                    self.epoch, batch_idx, self.loss_D.data[0], loss_G_GAN.data[0], loss_G_L2.data[0]))
                if batch_idx % 10 == 0:
                    print('--- epoch=%d, batch_idx=%d, D_loss=%.4f, G_loss=%.4f, L2_loss=%.4f' % (
                        self.epoch, batch_idx, self.loss_D.data[0], loss_G_GAN.data[0], loss_G_L2.data[0]))
            else:
                if self.loss_fun == 'cross_entropy':
                    arr = np.array(self.crossentropy_weight)
                    weight = torch.from_numpy(arr).cuda().float()
                    loss = cross_entropy2d(pred, target.long(), weight=weight)
                elif self.loss_fun == 'Dice':
                    loss = dice_loss(pred, target2ch)
                elif self.loss_fun == 'Dice_norm':
                    loss = dice_loss_norm(pred, target2ch)
            loss.backward()
            self.optim.step()
            if batch_idx % 10 == 0:
                print('epoch=%d, batch_idx=%d, loss=%.4f' % (self.epoch, batch_idx, loss.data[0]))
                fv.write('epoch=%d, batch_idx=%d, loss=%.4f \n' % (self.epoch, batch_idx, loss.data[0]))
        fv.close()

    def train_epoch(self):
        for epoch in tqdm.trange(self.epoch, self.max_epoch, desc='Train', ncols=80):
            self.epoch = epoch
            out = osp.join(self.out, 'models', self.view)
            mkdir(out)
            model_pth = '%s/model_epoch_%04d.pth' % (out, epoch)
            gan_model_pth = '%s/GAN_D_epoch_%04d.pth' % (out, epoch)
            if os.path.exists(model_pth):
                # A checkpoint for this epoch already exists: load it and skip training.
                self.model.load_state_dict(torch.load(model_pth))
                if self.GAN and os.path.exists(gan_model_pth):
                    self.netD.load_state_dict(torch.load(gan_model_pth))
            else:
                if not self.onlyEval:
                    self.train()
                    self.validate()
                    torch.save(self.model.state_dict(), model_pth)
                    if self.GAN:
                        torch.save(self.netD.state_dict(), gan_model_pth)
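# ImagePool is likewise imported from elsewhere in the repository. A minimal sketch
# of the usual CycleGAN-style history buffer, consistent with how query() is used in
# backward_D above: it keeps up to pool_size previously generated images and, once
# full, returns either the incoming image or a randomly chosen stored one, so the
# discriminator also sees slightly stale fakes. The real implementation samples per
# image within a batch; this sketch simplifies to whole batches.
import random

class ImagePoolSketch(object):
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:
            return image  # pool disabled: pass everything straight through
        if len(self.images) < self.pool_size:
            self.images.append(image)
            return image
        if random.random() > 0.5:
            idx = random.randint(0, self.pool_size - 1)
            stale = self.images[idx]
            self.images[idx] = image
            return stale
        return image


# A usage sketch for the Trainer above. The model/loader factories and the literal
# argument values are hypothetical placeholders; only the Trainer keyword arguments
# themselves come from this file.
if __name__ == '__main__':
    model = build_model()                      # hypothetical: any network with model(data) -> logits
    train_loader, test_loader = get_loaders()  # hypothetical: loaders yielding the 6-tuples used above
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    trainer = Trainer(cuda=torch.cuda.is_available(), model=model, optimizer=optim,
                      loss_fun='Dice', train_loader=train_loader, test_loader=test_loader,
                      lmk_num=2, view='view1', crossentropy_weight=[1.0, 1.0],
                      out='./output', max_epoch=50, network_num=1, batch_size=4, GAN=True)
    trainer.train_epoch()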