def create_metric_models(opt, device):
    """Build the evaluation networks used to score generated images.

    Each network is loaded, moved to ``device`` and switched to eval mode.
    An entry is ``None`` when ``opt`` disables that metric or the
    dataset/direction check does not match.

    Args:
        opt: parsed options; reads ``no_fid``, ``no_mIoU``, ``dataroot``,
            ``direction``, ``gpu_ids``, ``drn_path`` and ``deeplabv2_path``.
        device: target device for the networks.

    Returns:
        tuple: ``(inception_model, drn_model, deeplabv2_model)``.
    """
    # FID feature extractor: InceptionV3 block for 2048-dim features.
    inception_model = None
    if not opt.no_fid:
        inception_model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]])
        if len(opt.gpu_ids) > 1:
            inception_model = nn.DataParallel(inception_model, opt.gpu_ids)
        inception_model.to(device)
        inception_model.eval()

    # Cityscapes mIoU: 19-class DRN segmenter, only for direction 'BtoA'.
    # Note: this one is wrapped in DataParallel for any non-empty gpu_ids.
    drn_model = None
    if 'cityscapes' in opt.dataroot and opt.direction == 'BtoA':
        drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(drn_model, opt.drn_path, verbose=False)
        if len(opt.gpu_ids) > 0:
            drn_model = nn.DataParallel(drn_model, opt.gpu_ids)
        drn_model.to(device)
        drn_model.eval()

    # COCO mIoU: 182-class DeepLabV2 wrapped in multi-scale inference (MSC).
    deeplabv2_model = None
    if 'coco' in opt.dataroot and not opt.no_mIoU and opt.direction == 'BtoA':
        backbone = DeepLabV2(n_classes=182, n_blocks=[3, 4, 23, 3],
                             atrous_rates=[6, 12, 18, 24])
        deeplabv2_model = MSC(backbone, scales=[0.5, 0.75])
        util.load_network(deeplabv2_model, opt.deeplabv2_path, verbose=False)
        if len(opt.gpu_ids) > 1:
            deeplabv2_model = nn.DataParallel(deeplabv2_model, opt.gpu_ids)
        deeplabv2_model.to(device)
        deeplabv2_model.eval()

    return inception_model, drn_model, deeplabv2_model
class SPADEModel(BaseModel):
    """SPADE image-synthesis model built on ``SPADEModelModules``.

    Network forward passes, losses, optimizer creation, profiling and
    checkpoint I/O are delegated to ``SPADEModelModules``. This class handles:
    - option parsing;
    - label preprocessing (one-hot semantics plus optional instance edges);
    - the G/D update order;
    - periodic FID / mIoU evaluation;
    - optimizer checkpointing.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add SPADE options to ``parser`` and return the modified parser.

        Training-only options (restore paths, loss weights, TTUR and metric
        switches) and training defaults are added only when ``is_train`` is
        true. ``networks.modify_commandline_options`` is applied last.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.set_defaults(netG='sub_mobile_spade')
        parser.add_argument('--separable_conv_norm', type=str, default='instance',
                            choices=('none', 'instance', 'batch'),
                            help='whether to use instance norm for the separable convolutions')
        parser.add_argument('--norm_G', type=str, default='spadesyncbatch3x3',
                            help='instance normalization or batch normalization')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='more',
                            help="If 'more', adds upsampling layer between the two middle resnet blocks. "
                                 "If 'most', also add one more upsampling + resnet layer at the end of the generator")
        if is_train:
            parser.add_argument('--restore_G_path', type=str, default=None,
                                help='the path to restore the generator')
            parser.add_argument('--restore_D_path', type=str, default=None,
                                help='the path to restore the discriminator')
            parser.add_argument('--real_stat_path', type=str, required=True,
                                help='the path to load the groud-truth images information to compute FID.')
            parser.add_argument('--lambda_gan', type=float, default=1, help='weight for gan loss')
            parser.add_argument('--lambda_feat', type=float, default=10, help='weight for gan feature loss')
            parser.add_argument('--lambda_vgg', type=float, default=10, help='weight for vgg loss')
            parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
            parser.add_argument('--no_TTUR', action='store_true', help='Use TTUR training scheme')
            parser.add_argument('--no_fid', action='store_true', help='No FID evaluation during training')
            parser.add_argument('--no_mIoU', action='store_true', help='No mIoU evaluation during training '
                                                                       '(sometimes because there are CUDA memory)')
            # Training-time defaults (multi-scale D, cityscapes data, xavier init, ...).
            parser.set_defaults(netD='multi_scale', ndf=64, dataset_mode='cityscapes', batch_size=16,
                                print_freq=50, save_latest_freq=10000000000, save_epoch_freq=10,
                                nepochs=100, nepochs_decay=100, init_type='xavier')
        parser = networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        """Build the SPADE modules and, when training, optimizers and evaluation state.

        In test mode the modules are switched to eval mode. A training
        dataloader is created in both modes; ``calibrate`` iterates over it.
        """
        super(SPADEModel, self).__init__(opt)
        self.model_names = ['G']
        self.visual_names = ['labels', 'fake_B', 'real_B']
        self.modules = SPADEModelModules(opt).to(self.device)
        if len(opt.gpu_ids) > 0:
            self.modules = DataParallelWithCallback(self.modules, device_ids=opt.gpu_ids)
            # Keep a handle to the unwrapped module so attributes such as
            # netG/netD/config are reachable whether or not it is wrapped.
            self.modules_on_one_gpu = self.modules.module
        else:
            self.modules_on_one_gpu = self.modules
        if opt.isTrain:
            self.model_names.append('D')
            self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
            self.optimizer_G, self.optimizer_D = self.modules_on_one_gpu.create_optimizers()
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            if not opt.no_fid:
                block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
                self.inception_model = InceptionV3([block_idx])
                self.inception_model.to(self.device)
                self.inception_model.eval()
            if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
                self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
                util.load_network(self.drn_model, opt.drn_path, verbose=False)
                self.drn_model.to(self.device)
                self.drn_model.eval()
            self.eval_dataloader = create_eval_dataloader(self.opt)
            # Best/recent metric tracking used by evaluate_model.
            self.best_fid = 1e9
            self.best_mIoU = -1e9
            self.fids, self.mIoUs = [], []
            self.is_best = False
            # Precomputed real-image statistics passed to get_fid.
            self.npz = np.load(opt.real_stat_path)
        else:
            self.modules.eval()
        self.train_dataloader = create_train_dataloader(opt)

    def set_input(self, input):
        """Store a dataloader batch and build the generator input.

        Sets ``labels`` (raw label map on device), ``image_paths``,
        ``input_semantics`` and ``real_B``.
        """
        self.data = input
        self.image_paths = input['path']
        self.labels = input['label'].to(self.device)
        self.input_semantics, self.real_B = self.preprocess_input(input)

    def test(self, config=None):
        """Run the unwrapped generator without gradients, optionally with a sub-network ``config``."""
        with torch.no_grad():
            self.forward(on_one_gpu=True, config=config)

    def preprocess_input(self, data):
        """Move a batch to ``self.device`` and turn the label map into one-hot semantics.

        Mutates ``data`` in place: the label is cast to long, and the label,
        instance and image tensors are moved to the device. The label map
        is 4-D (unpacked as ``bs, _, h, w``). It is scattered into
        ``nc`` channels along dim 1, where ``nc`` is ``input_nc`` plus one
        when ``contain_dontcare_label`` is set. Unless ``no_instance`` is set,
        an instance-edge channel is concatenated.

        Returns:
            tuple: ``(input_semantics, data['image'])``.
        """
        # move to GPU and change data types
        data['label'] = data['label'].long()
        data['label'] = data['label'].to(self.device)
        data['instance'] = data['instance'].to(self.device)
        data['image'] = data['image'].to(self.device)

        # create one-hot label map
        label_map = data['label']
        bs, _, h, w = label_map.size()
        nc = self.opt.input_nc + 1 if self.opt.contain_dontcare_label \
            else self.opt.input_nc
        input_label = torch.zeros([bs, nc, h, w], device=self.device)
        input_semantics = input_label.scatter_(1, label_map, 1.0)

        # concatenate instance map if it exists
        if not self.opt.no_instance:
            inst_map = data['instance']
            instance_edge_map = self.get_edges(inst_map)
            input_semantics = torch.cat((input_semantics, instance_edge_map), dim=1)

        return input_semantics, data['image']

    def forward(self, on_one_gpu=False, config=None):
        """Generate ``fake_B`` from ``input_semantics``.

        If ``config`` is given it is first assigned to the unwrapped module.
        ``on_one_gpu`` selects the unwrapped module instead of the
        (possibly DataParallel-wrapped) ``self.modules``.
        """
        if config is not None:
            self.modules_on_one_gpu.config = config
        if on_one_gpu:
            self.fake_B = self.modules_on_one_gpu(self.input_semantics)
        else:
            self.fake_B = self.modules(self.input_semantics)

    def get_edges(self, t):
        """Return a float map that is 1 wherever ``t`` differs from a horizontal or vertical neighbour.

        ``t`` is indexed as a 4-D tensor. Both pixels of each differing pair
        are marked. The output has the same shape as ``t``.
        """
        edge = torch.zeros(t.size(), dtype=torch.uint8, device=self.device)
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | ((t[:, :, :, 1:] != t[:, :, :, :-1]).byte())
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | ((t[:, :, :, 1:] != t[:, :, :, :-1]).byte())
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | ((t[:, :, 1:, :] != t[:, :, :-1, :]).byte())
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | ((t[:, :, 1:, :] != t[:, :, :-1, :]).byte())
        return edge.float()

    def profile(self, config=None, verbose=True):
        """Profile the unwrapped modules on the first sample of the current batch.

        Returns:
            tuple: ``(macs, params)`` as reported by ``modules_on_one_gpu.profile``.
        """
        if config is not None:
            self.modules_on_one_gpu.config = config
        macs, params = self.modules_on_one_gpu.profile(self.input_semantics[:1])
        if verbose:
            print('MACs: %.3fG\tParams: %.3fM' % (macs / 1e9, params / 1e6), flush=True)
        return macs, params

    def backward_G(self):
        """Compute generator losses, record the ``G_*`` terms for logging, and backprop ``loss_G``."""
        losses = self.modules(self.input_semantics, self.real_B, mode='G_loss')
        # .mean() reduces whatever the (possibly DataParallel) call returns to a scalar.
        loss_G = losses['loss_G'].mean()
        for loss_name in self.loss_names:
            if loss_name.startswith('G'):
                setattr(self, 'loss_%s' % loss_name, losses[loss_name].detach().mean())
        loss_G.backward()

    def backward_D(self):
        """Compute discriminator losses, record the ``D_*`` terms for logging, and backprop ``loss_D``."""
        losses = self.modules(self.input_semantics, self.real_B, mode='D_loss')
        loss_D = losses['loss_D'].mean()
        for loss_name in self.loss_names:
            if loss_name.startswith('D'):
                setattr(self, 'loss_%s' % loss_name, losses[loss_name].detach().mean())
        loss_D.backward()

    def optimize_parameters(self, steps):
        """One training step: update G with D frozen, then update D. ``steps`` is unused here."""
        # self.forward()
        self.set_requires_grad(self.modules_on_one_gpu.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.set_requires_grad(self.modules_on_one_gpu.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    def evaluate_model(self, step):
        """Run the generator over the eval set and compute metrics.

        Saves up to 10 input/real/fake samples under
        ``<log_dir>/eval/<step>``. Computes FID unless ``no_fid`` is set,
        and cityscapes mIoU unless ``no_mIoU`` is set. For each metric it
        keeps a running window of the last 3 values and the best so far, and
        sets ``self.is_best`` when either metric improves.

        Returns:
            dict: ``metric/...`` entries for the computed metrics.
        """
        self.is_best = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.modules_on_one_gpu.netG.eval()
        torch.cuda.empty_cache()
        fakes, names = [], []
        ret = {}
        cnt = 0
        for i, data_i in enumerate(tqdm(self.eval_dataloader, desc='Eval ', position=2, leave=False)):
            self.set_input(data_i)
            self.test()
            fakes.append(self.fake_B.cpu())
            for j in range(len(self.image_paths)):
                short_path = ntpath.basename(self.image_paths[j])
                name = os.path.splitext(short_path)[0]
                names.append(name)
                # Only the first 10 samples are written to disk.
                if cnt < 10:
                    input_im = util.tensor2label(self.input_semantics[j], self.opt.input_nc + 2)
                    real_im = util.tensor2im(self.real_B[j])
                    fake_im = util.tensor2im(self.fake_B[j])
                    util.save_image(input_im, os.path.join(save_dir, 'input', '%s.png' % name), create_dir=True)
                    util.save_image(real_im, os.path.join(save_dir, 'real', '%s.png' % name), create_dir=True)
                    util.save_image(fake_im, os.path.join(save_dir, 'fake', '%s.png' % name), create_dir=True)
                cnt += 1
        if not self.opt.no_fid:
            fid = get_fid(fakes, self.inception_model, self.npz, device=self.device,
                          batch_size=self.opt.eval_batch_size, tqdm_position=2)
            if fid < self.best_fid:
                self.is_best = True
                self.best_fid = fid
            self.fids.append(fid)
            if len(self.fids) > 3:
                self.fids.pop(0)
            ret['metric/fid'] = fid
            ret['metric/fid-mean'] = sum(self.fids) / len(self.fids)
            ret['metric/fid-best'] = self.best_fid
        if 'cityscapes' in self.opt.dataroot and not self.opt.no_mIoU:
            mIoU = get_cityscapes_mIoU(fakes, names, self.drn_model, self.device,
                                       table_path=self.opt.table_path,
                                       data_dir=self.opt.cityscapes_path,
                                       batch_size=self.opt.eval_batch_size,
                                       num_workers=self.opt.num_threads, tqdm_position=2)
            if mIoU > self.best_mIoU:
                self.is_best = True
                self.best_mIoU = mIoU
            self.mIoUs.append(mIoU)
            if len(self.mIoUs) > 3:
                self.mIoUs = self.mIoUs[1:]
            ret['metric/mIoU'] = mIoU
            ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs)
            ret['metric/mIoU-best'] = self.best_mIoU
        self.modules_on_one_gpu.netG.train()
        torch.cuda.empty_cache()
        return ret

    def print_networks(self):
        """Print each network in ``model_names`` and its parameter count (in millions).

        The same text is also written to ``<log_dir>/net<name>.txt`` when
        ``opt`` has ``log_dir``.
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self.modules_on_one_gpu, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
                if hasattr(self.opt, 'log_dir'):
                    with open(os.path.join(self.opt.log_dir, 'net' + name + '.txt'), 'w') as f:
                        f.write(str(net) + '\n')
                        f.write('[Network %s] Total number of parameters : %.3f M\n' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def load_networks(self, verbose=True):
        """Load network weights and, when training with ``restore_O_path``, optimizer states.

        After restoring the optimizers, the learning rates are reset from
        ``opt.lr``:
        - equal G/D rates when ``no_TTUR`` is set;
        - otherwise G uses ``lr / 2`` and D uses ``lr * 2``.
        """
        self.modules_on_one_gpu.load_networks(verbose)
        if self.isTrain and self.opt.restore_O_path is not None:
            for i, optimizer in enumerate(self.optimizers):
                path = '%s-%d.pth' % (self.opt.restore_O_path, i)
                util.load_optimizer(optimizer, path, verbose)
            if self.opt.no_TTUR:
                G_lr, D_lr = self.opt.lr, self.opt.lr
            else:
                G_lr, D_lr = self.opt.lr / 2, self.opt.lr * 2
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = G_lr
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = D_lr

    def get_current_visuals(self):
        """Return an OrderedDict of the visual tensors named in ``visual_names`` that currently exist.

        train.py will display these images with visdom, and save the images to a HTML."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str) and hasattr(self, name):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def save_networks(self, epoch):
        """Save network weights via the modules, and each optimizer to ``<epoch>_optim-<i>.pth``."""
        self.modules_on_one_gpu.save_networks(epoch, self.save_dir)
        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            save_path = os.path.join(self.save_dir, save_filename)
            torch.save(optimizer.state_dict(), save_path)

    def calibrate(self, config):
        """Run the whole training set through the modules in ``'calibrate'`` mode for ``config``.

        ``config`` is deep-copied. The copy gets ``calibrate_bn = True`` before
        the first batch and is assigned to the unwrapped module on every
        batch. netG is in train mode during the pass and left in eval mode
        afterwards.
        NOTE(review): presumably this re-estimates normalization statistics
        for the selected sub-network (per the 'calibrate_bn' key); confirm
        in SPADEModelModules.
        """
        self.modules_on_one_gpu.netG.train()
        config = copy.deepcopy(config)
        for i, data in enumerate(self.train_dataloader):
            self.set_input(data)
            if i == 0:
                config['calibrate_bn'] = True
            self.modules_on_one_gpu.config = config
            self.modules(self.input_semantics, mode='calibrate')
        self.modules_on_one_gpu.netG.eval()
class Pix2PixModel(BaseModel): @staticmethod def modify_commandline_options(parser, is_train=True): assert is_train parser = super(Pix2PixModel, Pix2PixModel).modify_commandline_options( parser, is_train) parser.add_argument('--restore_G_path', type=str, default=None, help='the path to restore the generator') parser.add_argument('--restore_D_path', type=str, default=None, help='the path to restore the discriminator') parser.add_argument('--recon_loss_type', type=str, default='l1', choices=['l1', 'l2', 'smooth_l1'], help='the type of the reconstruction loss') parser.add_argument('--lambda_recon', type=float, default=100, help='weight for reconstruction loss') parser.add_argument('--lambda_gan', type=float, default=1, help='weight for gan loss') parser.add_argument( '--real_stat_path', type=str, required=True, help= 'the path to load the groud-truth images information to compute FID.' ) return parser def __init__(self, opt): """Initialize the pix2pix class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ assert opt.isTrain BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = ['G_gan', 'G_recon', 'D_real', 'D_fake'] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> self.visual_names = ['real_A', 'fake_B', 'real_B'] # specify the models you want to save to the disk. 
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> self.model_names = ['G', 'D'] # define networks (both generator and discriminator) self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) # define loss functions self.criterionGAN = GANLoss(opt.gan_mode).to(self.device) if opt.recon_loss_type == 'l1': self.criterionRecon = torch.nn.L1Loss() elif opt.recon_loss_type == 'l2': self.criterionRecon = torch.nn.MSELoss() elif opt.recon_loss_type == 'smooth_l1': self.criterionRecon = torch.nn.SmoothL1Loss() else: raise NotImplementedError( 'Unknown reconstruction loss type [%s]!' % opt.loss_type) # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.eval_dataloader = create_eval_dataloader(self.opt) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] self.inception_model = InceptionV3([block_idx]) self.inception_model.to(self.device) self.inception_model.eval() if 'cityscapes' in opt.dataroot: self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False) util.load_network(self.drn_model, opt.drn_path, verbose=False) if len(opt.gpu_ids) > 0: self.drn_model.to(self.device) self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids) self.drn_model.eval() self.best_fid = 1e9 self.best_mIoU = -1e9 self.fids, self.mIoUs = [], [] self.is_best = False self.Tacts, self.Sacts = {}, {} self.npz = np.load(opt.real_stat_path) def set_input(self, input): """Unpack 
input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap images in domain A and domain B. """ AtoB = self.opt.direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" self.fake_B = self.netG(self.real_A) # G(A) def backward_D(self): """Calculate GAN loss for the discriminator""" fake_AB = torch.cat((self.real_A, self.fake_B), 1).detach() real_AB = torch.cat((self.real_A, self.real_B), 1).detach() pred_fake = self.netD(fake_AB) self.loss_D_fake = self.criterionGAN(pred_fake, False, for_discriminator=True) pred_real = self.netD(real_AB) self.loss_D_real = self.criterionGAN(pred_real, True, for_discriminator=True) self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self): """Calculate GAN and L1 loss for the generator""" # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) pred_fake = self.netD(fake_AB) self.loss_G_gan = self.criterionGAN( pred_fake, True, for_discriminator=False) * self.opt.lambda_gan # Second, G(A) = B self.loss_G_recon = self.criterionRecon( self.fake_B, self.real_B) * self.opt.lambda_recon # combine loss and calculate gradients self.loss_G = self.loss_G_gan + self.loss_G_recon self.loss_G.backward() def optimize_parameters(self, steps): self.forward() # compute fake images: G(A) # update D self.set_requires_grad(self.netD, True) # enable backprop for D self.optimizer_D.zero_grad() # set D's gradients to zero self.backward_D() # calculate gradients for D self.optimizer_D.step() # update D's weights # update G self.set_requires_grad( self.netD, False) # D requires no gradients when 
optimizing G self.optimizer_G.zero_grad() # set G's gradients to zero self.backward_G() # calculate graidents for G self.optimizer_G.step() # udpate G's weights def evaluate_model(self, step): self.is_best = False save_dir = os.path.join(self.opt.log_dir, 'eval', str(step)) os.makedirs(save_dir, exist_ok=True) self.netG.eval() fakes, names = [], [] cnt = 0 for i, data_i in enumerate( tqdm(self.eval_dataloader, desc='Eval ', position=2, leave=False)): self.set_input(data_i) self.test() fakes.append(self.fake_B.cpu()) for j in range(len(self.image_paths)): short_path = ntpath.basename(self.image_paths[j]) name = os.path.splitext(short_path)[0] names.append(name) if cnt < 10: input_im = util.tensor2im(self.real_A[j]) real_im = util.tensor2im(self.real_B[j]) fake_im = util.tensor2im(self.fake_B[j]) util.save_image(input_im, os.path.join(save_dir, 'input', '%s.png' % name), create_dir=True) util.save_image(real_im, os.path.join(save_dir, 'real', '%s.png' % name), create_dir=True) util.save_image(fake_im, os.path.join(save_dir, 'fake', '%s.png' % name), create_dir=True) cnt += 1 fid = get_fid(fakes, self.inception_model, self.npz, device=self.device, batch_size=self.opt.eval_batch_size, tqdm_position=2) if fid < self.best_fid: self.is_best = True self.best_fid = fid self.fids.append(fid) if len(self.fids) > 3: self.fids.pop(0) ret = { 'metric/fid': fid, 'metric/fid-mean': sum(self.fids) / len(self.fids), 'metric/fid-best': self.best_fid } if 'cityscapes' in self.opt.dataroot: mIoU = get_cityscapes_mIoU(fakes, names, self.drn_model, self.device, table_path=self.opt.table_path, data_dir=self.opt.cityscapes_path, batch_size=self.opt.eval_batch_size, num_workers=self.opt.num_threads, tqdm_position=2) if mIoU > self.best_mIoU: self.is_best = True self.best_mIoU = mIoU self.mIoUs.append(mIoU) if len(self.mIoUs) > 3: self.mIoUs = self.mIoUs[1:] ret['metric/mIoU'] = mIoU ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs) ret['metric/mIoU-best'] = self.best_mIoU 
self.netG.train() return ret
def main(configs, opt, gpu_id, queue, verbose):
    """Evaluate sub-network ``configs`` on one GPU and put the result list on ``queue``.

    For each config:
    - profile its MACs;
    - if the MACs exceed ``opt.budget``, skip evaluation and record
      worst-case sentinel metrics;
    - otherwise generate the whole dataloader and compute FID (unless
      ``opt.no_fid``) and cityscapes mIoU (cityscapes with direction 'BtoA').

    Args:
        configs: iterable of sub-network configs accepted by ``model.profile``/``model.test``.
        opt: parsed options; ``opt.gpu_ids`` is overwritten with ``[gpu_id]``.
        gpu_id: GPU index this worker runs on.
        queue: object with ``put``; receives a single list of result dicts.
        verbose: forwarded to dataloader/model creation.
    """
    opt.gpu_ids = [gpu_id]
    dataloader = create_dataloader(opt, verbose)
    model = create_model(opt, verbose)
    model.setup(opt, verbose)
    device = model.device

    compute_mIoU = 'cityscapes' in opt.dataroot and opt.direction == 'BtoA'

    if not opt.no_fid:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        inception_model = InceptionV3([block_idx])
        inception_model.to(device)
        inception_model.eval()
    if compute_mIoU:
        drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(drn_model, opt.drn_path, verbose=False)
        if len(opt.gpu_ids) > 0:
            drn_model = nn.DataParallel(drn_model, opt.gpu_ids)
        # DataParallel requires the wrapped module's parameters to be on
        # device_ids[0]; the original never moved the model.
        drn_model.to(device)
        drn_model.eval()

    npz = np.load(opt.real_stat_path)
    results = []

    # Load one batch so that model.profile() has an input to measure.
    for data_i in dataloader:
        model.set_input(data_i)
        break

    for config in tqdm.tqdm(configs):
        macs, _ = model.profile(config)
        qualified = macs <= opt.budget

        fakes, names = [], []
        if qualified:
            for data_i in dataloader:
                model.set_input(data_i)
                model.test(config)
                fakes.append(model.fake_B.cpu())
                for path in model.get_image_paths():
                    short_path = ntpath.basename(path)
                    name = os.path.splitext(short_path)[0]
                    names.append(name)

        result = {'config_str': encode_config(config), 'macs': macs}
        if not opt.no_fid:
            if qualified:
                fid = get_fid(fakes, inception_model, npz, device, opt.batch_size, use_tqdm=False)
                result['fid'] = fid
            else:
                # Worst-case sentinel for over-budget configs (lower FID is better).
                result['fid'] = 1e9
        if compute_mIoU:
            if qualified:
                mIoU = get_cityscapes_mIoU(fakes, names, drn_model, device,
                                           data_dir=opt.cityscapes_path,
                                           batch_size=opt.batch_size,
                                           num_workers=opt.num_threads, use_tqdm=False)
                result['mIoU'] = mIoU
            else:
                # Worst-case sentinel (higher mIoU is better). The original
                # reused ``mIoU``, which is unbound on the first over-budget
                # config and stale afterwards.
                result['mIoU'] = 0
        print(result, flush=True)
        results.append(result)

    queue.put(results)
class CycleGANModel(BaseModel): """ This class implements the CycleGAN model, for learning image-to-image translation without paired data. The model training requires '--dataset_mode unaligned' dataset. By default, it uses a '--netG resnet_9blocks' ResNet generator, a '--netD basic' discriminator (PatchGAN introduced by pix2pix), and a least-square GANs objective ('--gan_mode lsgan'). CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses. A (source domain), B (target domain). Generators: G_A: A -> B; G_B: B -> A. Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A. Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper) Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper) Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper) Dropout is not used in the original CycleGAN paper. 
""" assert is_train parser = super(CycleGANModel, CycleGANModel).modify_commandline_options( parser, is_train) parser.add_argument('--restore_G_A_path', type=str, default=None, help='the path to restore the generator G_A') parser.add_argument('--restore_D_A_path', type=str, default=None, help='the path to restore the discriminator D_A') parser.add_argument('--restore_G_B_path', type=str, default=None, help='the path to restore the generator G_B') parser.add_argument('--restore_D_B_path', type=str, default=None, help='the path to restore the discriminator D_B') parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)') parser.add_argument( '--lambda_identity', type=float, default=0.5, help='use identity mapping. ' 'Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. ' 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1' ) parser.add_argument( '--real_stat_A_path', type=str, required=True, help= 'the path to load the ground-truth A images information to compute FID.' ) parser.add_argument( '--real_stat_B_path', type=str, required=True, help= 'the path to load the ground-truth B images information to compute FID.' ) parser.set_defaults(norm='instance', dataset_mode='unaligned', batch_size=1, ndf=64, gan_mode='lsgan', nepochs=100, nepochs_decay=100, save_epoch_freq=20) return parser def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ assert opt.isTrain assert opt.direction == 'AtoB' assert opt.dataset_mode == 'unaligned' super(CycleGANModel, self).__init__(opt) # specify the training losses you want to print out. 
The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = [ 'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B', 'G_idt_B' ] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.netG, input_nc=opt.input_nc, output_nc=opt.output_nc, ngf=opt.ngf, norm=opt.norm, dropout_rate=opt.dropout_rate, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt) self.netG_B = networks.define_G(opt.netG, input_nc=opt.input_nc, output_nc=opt.output_nc, ngf=opt.ngf, norm=opt.norm, dropout_rate=opt.dropout_rate, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt) self.netD_A = networks.define_D(opt.netD, input_nc=opt.output_nc, ndf=opt.ndf, n_layers_D=opt.n_layers_D, norm=opt.norm, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt) self.netD_B = networks.define_D(opt.netD, input_nc=opt.input_nc, ndf=opt.ndf, n_layers_D=opt.n_layers_D, norm=opt.norm, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt) if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert (opt.input_nc == 
opt.output_nc) self.fake_A_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images self.fake_B_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images # define loss functions self.criterionGAN = GANLoss(opt.gan_mode).to( self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB') self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA') block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] self.inception_model = InceptionV3([block_idx]) self.inception_model.to(self.device) self.inception_model.eval() if 'cityscapes' in opt.dataroot: self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False) util.load_network(self.drn_model, opt.drn_path, verbose=False) if len(opt.gpu_ids) > 0: self.drn_model.to(self.device) self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids) self.drn_model.eval() self.best_fid_A, self.best_fid_B = 1e9, 1e9 self.best_mIoU = -1e9 self.fids_A, self.fids_B = [], [] self.mIoUs = [] self.is_best = False self.npz_A = np.load(opt.real_stat_A_path) self.npz_B = np.load(opt.real_stat_B_path) def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap domain A and domain B. 
""" # Since it is a cycle. self.real_A = input['A'].to(self.device) self.real_B = input['B'].to(self.device) def set_single_input(self, input): self.real_A = input['A'].to(self.device) self.image_paths = input['A_paths'] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" self.fake_B = self.netG_A(self.real_A) # G_A(A) self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.fake_A = self.netG_B(self.real_B) # G_B(B) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator Parameters: netD (network) -- the discriminator D real (tensor array) -- real images fake (tensor array) -- images generated by a generator Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. """ # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def backward_G(self): """Calculate the loss for generators G_A and G_B""" lambda_idt = self.opt.lambda_identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed: ||G_A(B) - B|| self.idt_A = self.netG_A(self.real_B) self.loss_G_idt_A = self.criterionIdt( self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed: ||G_B(A) - A|| self.idt_B = self.netG_B(self.real_A) 
self.loss_G_idt_B = self.criterionIdt( self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_G_idt_A = 0 self.loss_G_idt_B = 0 # GAN loss D_A(G_A(A)) self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True, for_discriminator=False) # GAN loss D_B(G_B(B)) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True, for_discriminator=False) # Forward cycle loss || G_B(G_A(A)) - A|| self.loss_G_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss || G_A(G_B(B)) - B|| self.loss_G_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss and calculate gradients self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_G_cycle_A + self.loss_G_cycle_B + self.loss_G_idt_A + self.loss_G_idt_B self.loss_G.backward() def optimize_parameters(self, steps): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward self.forward() # compute fake images and reconstruction images. 
# G_A and G_B self.set_requires_grad( [self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero self.backward_G() # calculate gradients for G_A and G_B self.optimizer_G.step() # update G_A and G_B's weights # D_A and D_B self.set_requires_grad([self.netD_A, self.netD_B], True) self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero self.backward_D_A() # calculate gradients for D_A self.backward_D_B() # calculate gradients for D_B self.optimizer_D.step() # update D_A and D_B's weights def test_single_side(self, direction): generator = getattr(self, 'netG_%s' % direction[0]) with torch.no_grad(): self.fake_B = generator(self.real_A) def evaluate_model(self, step): ret = {} self.is_best = False save_dir = os.path.join(self.opt.log_dir, 'eval', str(step)) os.makedirs(save_dir, exist_ok=True) self.netG_A.eval() self.netG_B.eval() for direction in ['AtoB', 'BtoA']: eval_dataloader = getattr(self, 'eval_dataloader_' + direction) fakes, names = [], [] cnt = 0 for i, data_i in enumerate( tqdm(eval_dataloader, desc='Eval %s ' % direction, position=2, leave=False)): self.set_single_input(data_i) self.test_single_side(direction) fakes.append(self.fake_B.cpu()) for j in range(len(self.image_paths)): short_path = ntpath.basename(self.image_paths[j]) name = os.path.splitext(short_path)[0] names.append(name) if cnt < 10: input_im = util.tensor2im(self.real_A[j]) fake_im = util.tensor2im(self.fake_B[j]) util.save_image(input_im, os.path.join(save_dir, direction, 'input', '%s.png' % name), create_dir=True) util.save_image(fake_im, os.path.join(save_dir, direction, 'fake', '%s.png' % name), create_dir=True) cnt += 1 suffix = direction[-1] fid = get_fid(fakes, self.inception_model, getattr(self, 'npz_%s' % direction[-1]), device=self.device, batch_size=self.opt.eval_batch_size, tqdm_position=2) if fid < getattr(self, 'best_fid_%s' % suffix): self.is_best = True setattr(self, 
'best_fid_%s' % suffix, fid) fids = getattr(self, 'fids_%s' % suffix) fids.append(fid) if len(fids) > 3: fids.pop(0) ret['metric/fid_%s' % suffix] = fid ret['metric/fid_%s-mean' % suffix] = sum(getattr(self, 'fids_%s' % suffix)) / len( getattr(self, 'fids_%s' % suffix)) ret['metric/fid_%s-best' % suffix] = getattr( self, 'best_fid_%s' % suffix) if 'cityscapes' in self.opt.dataroot and direction == 'BtoA': mIoU = get_cityscapes_mIoU(fakes, names, self.drn_model, self.device, table_path=self.opt.table_path, data_dir=self.opt.cityscapes_path, batch_size=self.opt.eval_batch_size, num_workers=self.opt.num_threads, tqdm_position=2) if mIoU > self.best_mIoU: self.is_best = True self.best_mIoU = mIoU self.mIoUs.append(mIoU) if len(self.mIoUs) > 3: self.mIoUs = self.mIoUs[1:] ret['metric/mIoU'] = mIoU ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs) ret['metric/mIoU-best'] = self.best_mIoU self.netG_A.train() self.netG_B.train() return ret
class BaseSPADEDistiller(SPADEModel):
    """Shared base class for SPADE generator distillers.

    Wraps a teacher generator, a student generator and a discriminator in
    ``self.modules``. That is ``SPADEDistillerModules`` when ``opt`` has a
    ``distiller`` attribute, otherwise ``SPADESupernetModules``. The class also
    creates the G/D optimizers and optionally loads the evaluation networks:
    InceptionV3 for FID and DRN for Cityscapes mIoU. Subclasses must implement
    :meth:`evaluate_model`.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the distiller's command-line options on ``parser``.

        Args:
            parser: the ``argparse.ArgumentParser`` to extend.
            is_train: unused; the same options are added in every mode.

        Returns:
            The same ``parser``, for chaining.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        # Generator architectures accepted for both the teacher and the student.
        netG_choices = [
            'spade', 'mobile_spade', 'super_mobile_spade', 'sub_mobile_spade'
        ]
        parser.add_argument(
            '--separable_conv_norm',
            type=str,
            default='instance',
            choices=('none', 'instance', 'batch'),
            help='whether to use instance norm for the separable convolutions')
        parser.add_argument(
            '--num_upsampling_layers',
            choices=('normal', 'more', 'most'),
            default='more',
            help=
            "If 'more', adds upsampling layer between the two middle resnet blocks. "
            "If 'most', also add one more upsampling + resnet layer at the end of the generator"
        )
        parser.add_argument('--teacher_netG',
                            type=str,
                            default='mobile_spade',
                            help='specify teacher generator architecture',
                            choices=netG_choices)
        parser.add_argument('--student_netG',
                            type=str,
                            default='mobile_spade',
                            help='specify student generator architecture',
                            choices=netG_choices)
        parser.add_argument(
            '--teacher_ngf',
            type=int,
            default=64,
            help='the base number of filters of the teacher generator')
        parser.add_argument(
            '--student_ngf',
            type=int,
            default=48,
            help='the base number of filters of the student generator')
        parser.add_argument(
            '--teacher_norm_G',
            type=str,
            default='spadesyncbatch3x3',
            help=
            'instance normalization or batch normalization of the teacher model'
        )
        parser.add_argument(
            '--student_norm_G',
            type=str,
            default='spadesyncbatch3x3',
            help=
            'instance normalization or batch normalization of the student model'
        )
        parser.add_argument('--restore_teacher_G_path',
                            type=str,
                            required=True,
                            help='the path to restore the teacher generator')
        parser.add_argument('--restore_student_G_path',
                            type=str,
                            default=None,
                            help='the path to restore the student generator')
        parser.add_argument(
            '--restore_A_path',
            type=str,
            default=None,
            help='the path to restore the adaptors for distillation')
        parser.add_argument('--restore_D_path',
                            type=str,
                            default=None,
                            help='the path to restore the discriminator')
        parser.add_argument(
            '--restore_O_path',
            type=str,
            default=None,
            help='the path prefix to restore the optimizers '
            '(the i-th optimizer is loaded from <prefix>-<i>.pth)')
        parser.add_argument('--lambda_gan',
                            type=float,
                            default=1,
                            help='weight for gan loss')
        parser.add_argument('--lambda_feat',
                            type=float,
                            default=10,
                            help='weight for gan feature loss')
        parser.add_argument('--lambda_vgg',
                            type=float,
                            default=10,
                            help='weight for vgg loss')
        parser.add_argument('--lambda_distill',
                            type=float,
                            default=10,
                            help='weight for distillation loss')
        parser.add_argument('--beta2',
                            type=float,
                            default=0.999,
                            help='momentum term of adam')
        parser.add_argument('--no_TTUR',
                            action='store_true',
                            help='Do not use the TTUR training scheme')
        parser.add_argument('--no_fid',
                            action='store_true',
                            help='No FID evaluation during training')
        parser.add_argument('--no_mIoU',
                            action='store_true',
                            help='No mIoU evaluation during training '
                            '(sometimes because there is not enough CUDA memory)')
        parser.set_defaults(netD='multi_scale',
                            ndf=64,
                            dataset_mode='cityscapes',
                            batch_size=16,
                            print_freq=50,
                            save_latest_freq=10000000000,
                            save_epoch_freq=10,
                            nepochs=100,
                            nepochs_decay=100,
                            init_type='xavier')
        return parser

    def __init__(self, opt):
        """Build the distillation modules, optimizers and evaluation state.

        Args:
            opt: parsed options. Fields read here include ``gpu_ids``,
                ``no_fid``, ``no_mIoU``, ``dataroot``, ``drn_path`` and
                ``real_stat_path``; the presence of a ``distiller`` attribute
                selects the module class.
        """
        # Deliberately skip SPADEModel.__init__ and call the next initializer
        # in the MRO: the distiller builds its own networks below.
        super(SPADEModel, self).__init__(opt)
        self.model_names = ['G_student', 'G_teacher', 'D']
        self.visual_names = ['labels', 'Tfake_B', 'Sfake_B', 'real_B']
        self.loss_names = [
            'G_gan', 'G_feat', 'G_vgg', 'G_distill', 'D_real', 'D_fake'
        ]

        if hasattr(opt, 'distiller'):
            modules_cls = SPADEDistillerModules
        else:
            modules_cls = SPADESupernetModules
        self.modules = modules_cls(opt).to(self.device)
        if len(opt.gpu_ids) > 0:
            self.modules = DataParallelWithCallback(self.modules,
                                                    device_ids=opt.gpu_ids)
            # Direct handle to the wrapped module, for attribute access and
            # single-GPU calls.
            self.modules_on_one_gpu = self.modules.module
        else:
            self.modules_on_one_gpu = self.modules

        # One additional distillation-loss entry per mapping layer.
        for i in range(len(self.modules_on_one_gpu.mapping_layers)):
            self.loss_names.append('G_distill%d' % i)
        self.optimizer_G, self.optimizer_D = \
            self.modules_on_one_gpu.create_optimizers()
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        if not opt.no_fid:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_model = InceptionV3([block_idx])
            self.inception_model.to(self.device)
            self.inception_model.eval()
        if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            self.drn_model.to(self.device)
            self.drn_model.eval()

        self.eval_dataloader = create_eval_dataloader(self.opt)
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.npz = np.load(opt.real_stat_path)

    def forward(self, on_one_gpu=False, config=None):
        """Run teacher and student on ``self.input_semantics``.

        Args:
            on_one_gpu: if True, call the unwrapped module directly instead of
                the (possibly) data-parallel wrapper.
            config: if not None, assigned to ``modules_on_one_gpu.config``
                before the forward pass.

        Sets ``self.Tfake_B`` and ``self.Sfake_B`` to the two outputs.
        """
        if config is not None:
            self.modules_on_one_gpu.config = config
        runner = self.modules_on_one_gpu if on_one_gpu else self.modules
        self.Tfake_B, self.Sfake_B = runner(self.input_semantics)

    def load_networks(self, verbose=True):
        """Load network weights and, if requested, optimizer states.

        Optimizer ``i`` is loaded from ``'<opt.restore_O_path>-<i>.pth'``.
        After loading, the learning rates are overwritten with the values
        implied by the current options. With TTUR this is lr/2 for G and
        lr*2 for D; with ``--no_TTUR`` it is lr for both.
        """
        self.modules_on_one_gpu.load_networks(verbose)
        if self.opt.restore_O_path is not None:
            for i, optimizer in enumerate(self.optimizers):
                path = '%s-%d.pth' % (self.opt.restore_O_path, i)
                util.load_optimizer(optimizer, path, verbose)
            # The restored state carries the learning rates it was saved with;
            # reset them from the current options.
            if self.opt.no_TTUR:
                G_lr, D_lr = self.opt.lr, self.opt.lr
            else:
                G_lr, D_lr = self.opt.lr / 2, self.opt.lr * 2
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = G_lr
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = D_lr

    def save_networks(self, epoch):
        """Save network weights and each optimizer's state for ``epoch``.

        Optimizer ``i`` is written to ``'<save_dir>/<epoch>_optim-<i>.pth'``.
        """
        self.modules_on_one_gpu.save_networks(epoch, self.save_dir)
        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            save_path = os.path.join(self.save_dir, save_filename)
            torch.save(optimizer.state_dict(), save_path)

    def evaluate_model(self, step):
        """Evaluate the student; must be implemented by subclasses."""
        raise NotImplementedError

    def optimize_parameters(self, steps):
        """Do one generator update, then one discriminator update.

        The discriminator's gradients are disabled during the generator step.
        ``steps`` is accepted for interface compatibility and is not used here.
        """
        self.set_requires_grad(self.modules_on_one_gpu.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.set_requires_grad(self.modules_on_one_gpu.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()