def test_pix2pix_mIoU(model, opt):
    """Generate validation images with ``model`` and score them with mIoU.

    ``opt`` is modified in place so that the validation pass is deterministic
    (single image per batch, no shuffling, no flipping). Every generated image
    is written to ``<checkpoints_dir>/<name>/test_results`` and the collected
    outputs are then scored by ``get_mIoU`` using a DRN segmentation network.

    Parameters:
        model -- object providing model_eval / set_input / forward /
                 get_current_visuals, plus ``image_paths`` and ``device``
        opt   -- experiment options (mutated, see below)

    Returns:
        whatever ``get_mIoU`` returns for the generated images.
    """
    # Overwrite the options that control data loading for evaluation.
    forced_options = (
        ('phase', 'val'),
        ('num_threads', 0),
        ('batch_size', 1),
        ('serial_batches', True),
        ('no_flip', True),
        ('load_size', 256),
        ('display_id', -1),
    )
    for key, value in forced_options:
        setattr(opt, key, value)

    dataset = create_dataset(opt)
    model.model_eval()

    result_dir = os.path.join(opt.checkpoints_dir, opt.name, 'test_results')
    util.mkdirs(result_dir)

    generated = {}  # input path -> generated image
    names = []  # unique image stems, in first-seen order
    for data in dataset:
        model.set_input(data)
        with torch.no_grad():
            model.forward()

        visuals = model.get_current_visuals()
        generated[data['A_paths'][0]] = visuals['fake_B']

        # NOTE(review): every iteration reads image_paths[0][0], not the
        # loop index -- presumably image_paths is nested for this model and
        # the de-duplication below absorbs the repeats; confirm.
        for _ in range(len(model.image_paths)):
            stem = os.path.splitext(
                ntpath.basename(model.image_paths[0][0]))[0]
            if stem not in names:
                names.append(stem)

        util.save_images(visuals,
                         model.image_paths,
                         result_dir,
                         direction=opt.direction,
                         aspect_ratio=opt.aspect_ratio)

    # Segmentation network used to score the generated images.
    drn_model = DRNSeg('drn_d_105', 19, pretrained=False).to(model.device)
    util.load_network(drn_model, opt.drn_path, verbose=False)
    drn_model.eval()

    return get_mIoU(list(generated.values()),
                    names,
                    drn_model,
                    model.device,
                    table_path=os.path.join(opt.dataroot, 'table.txt'),
                    data_dir=opt.dataroot,
                    batch_size=opt.batch_size,
                    num_workers=opt.num_threads)
class Pix2PixModel(BaseModel): @staticmethod def modify_commandline_options(parser, is_train=True): assert is_train parser = super(Pix2PixModel, Pix2PixModel).modify_commandline_options( parser, is_train) parser.add_argument('--restore_G_path', type=str, default=None, help='the path to restore the generator') parser.add_argument('--restore_D_path', type=str, default=None, help='the path to restore the discriminator') parser.add_argument('--recon_loss_type', type=str, default='l1', choices=['l1', 'l2', 'smooth_l1'], help='the type of the reconstruction loss') parser.add_argument('--lambda_recon', type=float, default=100, help='weight for reconstruction loss') parser.add_argument('--lambda_gan', type=float, default=1, help='weight for gan loss') parser.add_argument( '--real_stat_path', type=str, required=True, help= 'the path to load the groud-truth images information to compute FID.' ) return parser def __init__(self, opt): """Initialize the pix2pix class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ assert opt.isTrain BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = ['G_gan', 'G_recon', 'D_real', 'D_fake'] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> self.visual_names = ['real_A', 'fake_B', 'real_B'] # specify the models you want to save to the disk. 
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> self.model_names = ['G', 'D'] # define networks (both generator and discriminator) self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) # define loss functions self.criterionGAN = GANLoss(opt.gan_mode).to(self.device) if opt.recon_loss_type == 'l1': self.criterionRecon = torch.nn.L1Loss() elif opt.recon_loss_type == 'l2': self.criterionRecon = torch.nn.MSELoss() elif opt.recon_loss_type == 'smooth_l1': self.criterionRecon = torch.nn.SmoothL1Loss() else: raise NotImplementedError( 'Unknown reconstruction loss type [%s]!' % opt.loss_type) # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.eval_dataloader = create_eval_dataloader(self.opt) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] self.inception_model = InceptionV3([block_idx]) self.inception_model.to(self.device) self.inception_model.eval() if 'cityscapes' in opt.dataroot: self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False) util.load_network(self.drn_model, opt.drn_path, verbose=False) if len(opt.gpu_ids) > 0: self.drn_model.to(self.device) self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids) self.drn_model.eval() self.best_fid = 1e9 self.best_mIoU = -1e9 self.fids, self.mIoUs = [], [] self.is_best = False self.Tacts, self.Sacts = {}, {} self.npz = np.load(opt.real_stat_path) def set_input(self, input): """Unpack 
input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap images in domain A and domain B. """ AtoB = self.opt.direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" self.fake_B = self.netG(self.real_A) # G(A) def backward_D(self): """Calculate GAN loss for the discriminator""" fake_AB = torch.cat((self.real_A, self.fake_B), 1).detach() real_AB = torch.cat((self.real_A, self.real_B), 1).detach() pred_fake = self.netD(fake_AB) self.loss_D_fake = self.criterionGAN(pred_fake, False, for_discriminator=True) pred_real = self.netD(real_AB) self.loss_D_real = self.criterionGAN(pred_real, True, for_discriminator=True) self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self): """Calculate GAN and L1 loss for the generator""" # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) pred_fake = self.netD(fake_AB) self.loss_G_gan = self.criterionGAN( pred_fake, True, for_discriminator=False) * self.opt.lambda_gan # Second, G(A) = B self.loss_G_recon = self.criterionRecon( self.fake_B, self.real_B) * self.opt.lambda_recon # combine loss and calculate gradients self.loss_G = self.loss_G_gan + self.loss_G_recon self.loss_G.backward() def optimize_parameters(self, steps): self.forward() # compute fake images: G(A) # update D self.set_requires_grad(self.netD, True) # enable backprop for D self.optimizer_D.zero_grad() # set D's gradients to zero self.backward_D() # calculate gradients for D self.optimizer_D.step() # update D's weights # update G self.set_requires_grad( self.netD, False) # D requires no gradients when 
optimizing G self.optimizer_G.zero_grad() # set G's gradients to zero self.backward_G() # calculate graidents for G self.optimizer_G.step() # udpate G's weights def evaluate_model(self, step): self.is_best = False save_dir = os.path.join(self.opt.log_dir, 'eval', str(step)) os.makedirs(save_dir, exist_ok=True) self.netG.eval() fakes, names = [], [] cnt = 0 for i, data_i in enumerate(tqdm(self.eval_dataloader)): self.set_input(data_i) self.test() fakes.append(self.fake_B.cpu()) for j in range(len(self.image_paths)): short_path = ntpath.basename(self.image_paths[j]) name = os.path.splitext(short_path)[0] names.append(name) if cnt < 10: input_im = util.tensor2im(self.real_A[j]) real_im = util.tensor2im(self.real_B[j]) fake_im = util.tensor2im(self.fake_B[j]) util.save_image(input_im, os.path.join(save_dir, 'input', '%s.png' % name), create_dir=True) util.save_image(real_im, os.path.join(save_dir, 'real', '%s.png' % name), create_dir=True) util.save_image(fake_im, os.path.join(save_dir, 'fake', '%s.png' % name), create_dir=True) cnt += 1 fid = get_fid(fakes, self.inception_model, self.npz, device=self.device, batch_size=self.opt.eval_batch_size) if fid < self.best_fid: self.is_best = True self.best_fid = fid self.fids.append(fid) if len(self.fids) > 3: self.fids.pop(0) ret = { 'metric/fid': fid, 'metric/fid-mean': sum(self.fids) / len(self.fids), 'metric/fid-best': self.best_fid } if 'cityscapes' in self.opt.dataroot: mIoU = get_mIoU(fakes, names, self.drn_model, self.device, table_path=self.opt.table_path, data_dir=self.opt.cityscapes_path, batch_size=self.opt.eval_batch_size, num_workers=self.opt.num_threads) if mIoU > self.best_mIoU: self.is_best = True self.best_mIoU = mIoU self.mIoUs.append(mIoU) if len(self.mIoUs) > 3: self.mIoUs = self.mIoUs[1:] ret['metric/mIoU'] = mIoU ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs) ret['metric/mIoU-best'] = self.best_mIoU self.netG.train() return ret
class CycleGANModel(BaseModel): """ This class implements the CycleGAN model, for learning image-to-image translation without paired data. The model training requires '--dataset_mode unaligned' dataset. By default, it uses a '--netG resnet_9blocks' ResNet generator, a '--netD basic' discriminator (PatchGAN introduced by pix2pix), and a least-square GANs objective ('--gan_mode lsgan'). CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses. A (source domain), B (target domain). Generators: G_A: A -> B; G_B: B -> A. Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A. Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper) Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper) Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper) Dropout is not used in the original CycleGAN paper. 
""" assert is_train parser = super(CycleGANModel, CycleGANModel).modify_commandline_options( parser, is_train) parser.add_argument('--restore_G_A_path', type=str, default=None, help='the path to restore the generator G_A') parser.add_argument('--restore_D_A_path', type=str, default=None, help='the path to restore the discriminator D_A') parser.add_argument('--restore_G_B_path', type=str, default=None, help='the path to restore the generator G_B') parser.add_argument('--restore_D_B_path', type=str, default=None, help='the path to restore the discriminator D_B') parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)') parser.add_argument( '--lambda_identity', type=float, default=0.5, help='use identity mapping. ' 'Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. ' 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1' ) parser.add_argument( '--real_stat_A_path', type=str, required=True, help= 'the path to load the ground-truth A images information to compute FID.' ) parser.add_argument( '--real_stat_B_path', type=str, required=True, help= 'the path to load the ground-truth B images information to compute FID.' ) parser.set_defaults(norm='instance', dataset_mode='unaligned', batch_size=1, ndf=64, gan_mode='lsgan', nepochs=100, nepochs_decay=100, save_epoch_freq=20) return parser def __init__(self, opt): """Initialize the CycleGAN class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ assert opt.isTrain assert opt.direction == 'AtoB' assert opt.dataset_mode == 'unaligned' BaseModel.__init__(self, opt) # specify the training losses you want to print out. 
The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = [ 'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B', 'G_idt_B' ] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, opt.dropout_rate, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert (opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images self.fake_B_pool = ImagePool( opt.pool_size ) # create image buffer to store previously generated images # define loss functions self.criterionGAN = 
models.modules.loss.GANLoss(opt.gan_mode).to( self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB') self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA') block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] self.inception_model = InceptionV3([block_idx]) self.inception_model.to(self.device) self.inception_model.eval() if 'cityscapes' in opt.dataroot: self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False) util.load_network(self.drn_model, opt.drn_path, verbose=False) if len(opt.gpu_ids) > 0: self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids) self.drn_model.eval() self.best_fid_A, self.best_fid_B = 1e9, 1e9 self.best_mIoU = -1e9 self.fids_A, self.fids_B = [], [] self.mIoUs = [] self.is_best = False self.npz_A = np.load(opt.real_stat_A_path) self.npz_B = np.load(opt.real_stat_B_path) def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap domain A and domain B. """ # Since it is a cycle. 
self.real_A = input['A'].to(self.device) self.real_B = input['B'].to(self.device) def set_single_input(self, input): self.real_A = input['A'].to(self.device) self.image_paths = input['A_paths'] def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" self.fake_B = self.netG_A(self.real_A) # G_A(A) self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.fake_A = self.netG_B(self.real_B) # G_B(B) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator Parameters: netD (network) -- the discriminator D real (tensor array) -- real images fake (tensor array) -- images generated by a generator Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. """ # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def backward_G(self): """Calculate the loss for generators G_A and G_B""" lambda_idt = self.opt.lambda_identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed: ||G_A(B) - B|| self.idt_A = self.netG_A(self.real_B) self.loss_G_idt_A = self.criterionIdt( self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed: ||G_B(A) - A|| self.idt_B = self.netG_B(self.real_A) self.loss_G_idt_B = 
self.criterionIdt( self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_G_idt_A = 0 self.loss_G_idt_B = 0 # GAN loss D_A(G_A(A)) self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) # GAN loss D_B(G_B(B)) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) # Forward cycle loss || G_B(G_A(A)) - A|| self.loss_G_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss || G_A(G_B(B)) - B|| self.loss_G_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss and calculate gradients self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_G_cycle_A + self.loss_G_cycle_B + self.loss_G_idt_A + self.loss_G_idt_B self.loss_G.backward() def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward self.forward() # compute fake images and reconstruction images. # G_A and G_B self.set_requires_grad( [self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero self.backward_G() # calculate gradients for G_A and G_B self.optimizer_G.step() # update G_A and G_B's weights # D_A and D_B self.set_requires_grad([self.netD_A, self.netD_B], True) self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero self.backward_D_A() # calculate gradients for D_A self.backward_D_B() # calculate graidents for D_B self.optimizer_D.step() # update D_A and D_B's weights def test_single_side(self, direction): generator = getattr(self, 'netG_%s' % direction[0]) with torch.no_grad(): self.fake_B = generator(self.real_A) def evaluate_model(self, step): ret = {} self.is_best = False save_dir = os.path.join(self.opt.log_dir, 'eval', str(step)) os.makedirs(save_dir, exist_ok=True) self.netG_A.eval() self.netG_B.eval() for direction in ['AtoB', 'BtoA']: eval_dataloader = getattr(self, 'eval_dataloader_' + direction) fakes, names = [], [] cnt = 0 # 
print(len(eval_dataset)) for i, data_i in enumerate(tqdm(eval_dataloader)): self.set_single_input(data_i) self.test_single_side(direction) # print(self.image_paths) fakes.append(self.fake_B.cpu()) for j in range(len(self.image_paths)): short_path = ntpath.basename(self.image_paths[j]) name = os.path.splitext(short_path)[0] names.append(name) if cnt < 10: input_im = util.tensor2im(self.real_A[j]) fake_im = util.tensor2im(self.fake_B[j]) util.save_image(input_im, os.path.join(save_dir, direction, 'input', '%s.png' % name), create_dir=True) util.save_image(fake_im, os.path.join(save_dir, direction, 'fake', '%s.png' % name), create_dir=True) cnt += 1 suffix = direction[-1] fid = get_fid(fakes, self.inception_model, getattr(self, 'npz_%s' % direction[-1]), device=self.device, batch_size=self.opt.eval_batch_size) if fid < getattr(self, 'best_fid_%s' % suffix): self.is_best = True setattr(self, 'best_fid_%s' % suffix, fid) fids = getattr(self, 'fids_%s' % suffix) fids.append(fid) if len(fids) > 3: fids.pop(0) ret['metric/fid_%s' % suffix] = fid ret['metric/fid_%s-mean' % suffix] = sum(getattr(self, 'fids_%s' % suffix)) / len( getattr(self, 'fids_%s' % suffix)) ret['metric/fid_%s-best' % suffix] = getattr( self, 'best_fid_%s' % suffix) if 'cityscapes' in self.opt.dataroot and direction == 'BtoA': mIoU = get_mIoU(fakes, names, self.drn_model, self.device, table_path=self.opt.table_path, data_dir=self.opt.cityscapes_path, batch_size=self.opt.eval_batch_size, num_workers=self.opt.num_threads) if mIoU > self.best_mIoU: self.is_best = True self.best_mIoU = mIoU self.mIoUs.append(mIoU) if len(self.mIoUs) > 3: self.mIoUs = self.mIoUs[1:] ret['metric/mIoU'] = mIoU ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs) ret['metric/mIoU-best'] = self.best_mIoU self.netG_A.train() self.netG_B.train() return ret
# Script fragment: build the data pipeline, the model and the evaluation
# networks, then start profiling candidate configurations against a budget.
# NOTE(review): the nesting below is reconstructed from a whitespace-collapsed
# source, and `opt` / `configs` are defined outside this fragment -- confirm.
dataloader = create_dataloader(opt)
model = create_model(opt)
model.setup(opt)
device = model.device

# Inception network, built only when FID evaluation is not disabled.
if not opt.no_fid:
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    inception_model = InceptionV3([block_idx])
    inception_model.to(device)
    inception_model.eval()

# DRN segmentation network, built only for Cityscapes in the BtoA direction.
if 'cityscapes' in opt.dataroot and opt.direction == 'BtoA':
    drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
    util.load_network(drn_model, opt.drn_path, verbose=False)
    if len(opt.gpu_ids) > 0:
        drn_model = nn.DataParallel(drn_model, opt.gpu_ids)
    drn_model.eval()

npz = np.load(opt.real_stat_path)
results = []

# Feed exactly one batch to the model, then stop iterating.
for data_i in dataloader:
    model.set_input(data_i)
    break

# Profile each candidate and flag whether its MACs fit within opt.budget.
# NOTE(review): `qualified` is not used within the visible fragment; the loop
# body presumably continues beyond what is shown here.
for config in tqdm.tqdm(configs):
    qualified = True
    macs, _ = model.profile(config)
    if macs > opt.budget:
        qualified = False
    else:
        qualified = True
class SPADEModel(BaseModel):
    """SPADE model wrapper: owns the network modules, optimizers and evaluation state."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the SPADE options on ``parser`` and return the parser.

        Training-only options and defaults are added when ``is_train`` is
        true; the result is finally passed through
        ``networks.modify_commandline_options``.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.set_defaults(netG='inception_spade')
        parser.add_argument(
            '--norm_G',
            type=str,
            default='spadesyncbatch3x3',
            help='instance normalization or batch normalization')
        parser.add_argument(
            '--num_upsampling_layers',
            choices=('normal', 'more', 'most'),
            default='more',
            help=
            "If 'more', adds upsampling layer between the two middle resnet blocks. "
            "If 'most', also add one more upsampling + resnet layer at the end of the generator"
        )
        if is_train:
            # checkpoint paths to resume from
            restore_flags = (
                ('--restore_G_path', 'the path to restore the generator'),
                ('--restore_D_path',
                 'the path to restore the discriminator'),
            )
            for flag, text in restore_flags:
                parser.add_argument(flag, type=str, default=None, help=text)
            parser.add_argument(
                '--real_stat_path',
                type=str,
                required=True,
                help=
                'the path to load the groud-truth images information to compute FID.'
            )
            # loss weights and optimizer hyper-parameters
            float_flags = (
                ('--lambda_gan', 1, 'weight for gan loss'),
                ('--lambda_feat', 10, 'weight for gan feature loss'),
                ('--lambda_vgg', 10, 'weight for vgg loss'),
                ('--beta2', 0.999, 'momentum term of adam'),
            )
            for flag, default, text in float_flags:
                parser.add_argument(flag,
                                    type=float,
                                    default=default,
                                    help=text)
            # boolean switches
            switch_flags = (
                ('--no_TTUR', 'Use TTUR training scheme'),
                ('--no_fid', 'No FID evaluation during training'),
                ('--no_mIoU', 'No mIoU evaluation during training '
                 '(sometimes because there are CUDA memory)'),
            )
            for flag, text in switch_flags:
                parser.add_argument(flag, action='store_true', help=text)
            parser.set_defaults(netD='multi_scale',
                                ndf=64,
                                dataset_mode='cityscapes',
                                batch_size=16,
                                print_freq=50,
                                save_latest_freq=10000000000,
                                save_epoch_freq=10,
                                nepochs=100,
                                nepochs_decay=100,
                                init_type='xavier',
                                active_fn='nn.LeakyReLU')
        return networks.modify_commandline_options(parser, is_train)

    def __init__(self, opt):
        """Build the network modules and, when training, the optimizers and evaluation helpers."""
        super().__init__(opt)
        self.model_names = ['G']
        self.visual_names = ['labels', 'fake_B', 'real_B']

        self.modules = SPADEModelModules(opt).to(self.device)
        if len(opt.gpu_ids) > 0:
            self.modules = DataParallelWithCallback(self.modules,
                                                    device_ids=opt.gpu_ids)
            # unwrapped module, used wherever a single-device call is wanted
            self.modules_on_one_gpu = self.modules.module
        else:
            self.modules_on_one_gpu = self.modules

        if not opt.isTrain:
            self.modules.eval()
            return

        self.model_names.append('D')
        self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
        optimizers = self.modules_on_one_gpu.create_optimizers()
        self.optimizer_G, self.optimizer_D = optimizers
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        if not opt.no_fid:
            feature_block = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_model = InceptionV3([feature_block])
            self.inception_model.to(self.device)
            self.inception_model.eval()
        if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            self.drn_model.to(self.device)
            self.drn_model.eval()

        self.eval_dataloader = create_eval_dataloader(self.opt)
        # running bests and the last few scores (see evaluate_model)
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.npz = np.load(opt.real_stat_path)

    def set_input(self, input):
        """Store one batch and derive the generator input from it."""
        self.data = input
        self.image_paths = input['path']
        self.labels = input['label'].to(self.device)
        semantics, image = self.preprocess_input(input)
        self.input_semantics, self.real_B = semantics, image

    def test(self):
        """Run a single-device forward pass without tracking gradients."""
        with torch.no_grad():
            self.forward(on_one_gpu=True)

    def preprocess_input(self, data):
        """Turn a batch dict into (one-hot semantics, real image).

        The 'label', 'instance' and 'image' entries of ``data`` are replaced
        in place by their on-device versions. The label map is scattered
        into a one-hot tensor, and unless ``opt.no_instance`` is set the
        instance edge map is appended along the channel dimension.
        """
        data['label'] = data['label'].long().to(self.device)
        for key in ('instance', 'image'):
            data[key] = data[key].to(self.device)

        label_map = data['label']
        batch, _, height, width = label_map.size()
        if self.opt.contain_dontcare_label:
            channels = self.opt.input_nc + 1
        else:
            channels = self.opt.input_nc
        one_hot = torch.zeros([batch, channels, height, width],
                              device=self.device)
        input_semantics = one_hot.scatter_(1, label_map, 1.0)

        if not self.opt.no_instance:
            edges = self.get_edges(data['instance'])
            input_semantics = torch.cat((input_semantics, edges), dim=1)
        return input_semantics, data['image']

    def forward(self, on_one_gpu=False):
        """Generate ``fake_B`` from the current semantics, optionally on the unwrapped module."""
        net = self.modules_on_one_gpu if on_one_gpu else self.modules
        self.fake_B = net(self.input_semantics)

    def get_edges(self, t):
        """Return a float map that is 1 where ``t`` differs from a horizontal or vertical neighbour."""
        edge = torch.zeros(t.size(), dtype=torch.uint8, device=self.device)
        # neighbour differences along the last and second-to-last dimensions
        diff_w = (t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
        diff_h = (t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
        # mark both positions of every differing pair
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | diff_w
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | diff_w
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | diff_h
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | diff_h
        return edge.float()

    def profile(self, verbose=True):
        """Profile the unwrapped module on the first sample; return (macs, params)."""
        single_sample = self.input_semantics[:1]
        macs, params = self.modules_on_one_gpu.profile(single_sample)
        if verbose:
            print('MACs: %.3fG\tParams: %.3fM' % (macs / 1e9, params / 1e6),
                  flush=True)
        return macs, params

    def backward_G(self):
        """Compute the generator losses, record the 'G*' terms and backpropagate."""
        losses = self.modules(self.input_semantics,
                              self.real_B,
                              mode='G_loss')
        total = losses['loss_G'].mean()
        for loss_name in self.loss_names:
            if not loss_name.startswith('G'):
                continue
            setattr(self, 'loss_%s' % loss_name,
                    losses[loss_name].detach().mean())
        total.backward()

    def backward_D(self):
        """Compute the discriminator losses, record the 'D*' terms and backpropagate."""
        losses = self.modules(self.input_semantics,
                              self.real_B,
                              mode='D_loss')
        total = losses['loss_D'].mean()
        for loss_name in self.loss_names:
            if not loss_name.startswith('D'):
                continue
            setattr(self, 'loss_%s' % loss_name,
                    losses[loss_name].detach().mean())
        total.backward()

    def optimize_parameters(self, steps):
        """Run one training iteration: update G with D frozen, then update D."""
        discriminator = self.modules_on_one_gpu.netD

        # generator step
        self.set_requires_grad(discriminator, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # discriminator step
        self.set_requires_grad(discriminator, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    def evaluate_model(self, step, save_image=False):
        """Evaluate the generator on the evaluation set.

        Saves input/real/fake images under ``<log_dir>/eval/<step>`` for the
        first 10 samples (or for all of them when ``save_image`` is true),
        then computes FID and mIoU unless they are disabled by the options.
        ``self.is_best`` is set when a metric improves on its best so far.

        Returns:
            dict of 'metric/...' entries (empty when both metrics are off).
        """
        self.is_best = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.modules_on_one_gpu.netG.eval()
        torch.cuda.empty_cache()

        fakes, names = [], []
        ret = {}
        cnt = 0  # number of samples seen so far, across batches
        for data_i in tqdm(self.eval_dataloader):
            self.set_input(data_i)
            self.test()
            fakes.append(self.fake_B.cpu())
            for j, image_path in enumerate(self.image_paths):
                name = os.path.splitext(ntpath.basename(image_path))[0]
                names.append(name)
                if cnt < 10 or save_image:
                    filename = '%s.png' % name
                    input_im = util.tensor2label(self.input_semantics[j],
                                                 self.opt.input_nc + 2)
                    real_im = util.tensor2im(self.real_B[j])
                    fake_im = util.tensor2im(self.fake_B[j])
                    outputs = (('input', input_im), ('real', real_im),
                               ('fake', fake_im))
                    for folder, image in outputs:
                        util.save_image(image,
                                        os.path.join(save_dir, folder,
                                                     filename),
                                        create_dir=True)
                cnt += 1

        if not self.opt.no_fid:
            fid = get_fid(fakes,
                          self.inception_model,
                          self.npz,
                          device=self.device,
                          batch_size=self.opt.eval_batch_size)
            if fid < self.best_fid:
                self.is_best = True
                self.best_fid = fid
            # keep only the three most recent scores for the running mean
            self.fids.append(fid)
            if len(self.fids) > 3:
                self.fids.pop(0)
            ret['metric/fid'] = fid
            ret['metric/fid-mean'] = sum(self.fids) / len(self.fids)
            ret['metric/fid-best'] = self.best_fid

        if 'cityscapes' in self.opt.dataroot and not self.opt.no_mIoU:
            mIoU = get_mIoU(fakes,
                            names,
                            self.drn_model,
                            self.device,
                            table_path=self.opt.table_path,
                            data_dir=self.opt.cityscapes_path,
                            batch_size=self.opt.eval_batch_size,
                            num_workers=self.opt.num_threads)
            if mIoU > self.best_mIoU:
                self.is_best = True
                self.best_mIoU = mIoU
            self.mIoUs.append(mIoU)
            if len(self.mIoUs) > 3:
                self.mIoUs = self.mIoUs[1:]
            ret['metric/mIoU'] = mIoU
            ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs)
            ret['metric/mIoU-best'] = self.best_mIoU

        self.modules_on_one_gpu.netG.train()
        torch.cuda.empty_cache()
        return ret

    def print_networks(self):
        """Print each network and its parameter count; also write them to ``log_dir`` when set."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            net = getattr(self.modules_on_one_gpu, 'net' + name)
            num_params = sum(param.numel() for param in net.parameters())
            print(net)
            print('[Network %s] Total number of parameters : %.3f M' %
                  (name, num_params / 1e6))
            if hasattr(self.opt, 'log_dir'):
                log_path = os.path.join(self.opt.log_dir,
                                        'net' + name + '.txt')
                with open(log_path, 'w') as f:
                    f.write(str(net) + '\n')
                    f.write(
                        '[Network %s] Total number of parameters : %.3f M\n' %
                        (name, num_params / 1e6))
        print('-----------------------------------------------')

    def load_networks(self,
                      verbose=True,
                      teacher_only=False,
                      restore_pretrain=True):
        """Load the network weights and, when resuming training, the optimizer states.

        When optimizer states are restored, the learning rates are reset
        from ``opt.lr``: equal for G and D with ``--no_TTUR``, otherwise
        lr / 2 for G and lr * 2 for D. ``teacher_only`` and
        ``restore_pretrain`` are accepted but not used here.
        """
        self.modules_on_one_gpu.load_networks(verbose)
        if not (self.isTrain and self.opt.restore_O_path is not None):
            return

        for i, optimizer in enumerate(self.optimizers):
            path = '%s-%d.pth' % (self.opt.restore_O_path, i)
            util.load_optimizer(optimizer, path, verbose)

        if self.opt.no_TTUR:
            G_lr, D_lr = self.opt.lr, self.opt.lr
        else:
            G_lr, D_lr = self.opt.lr / 2, self.opt.lr * 2
        for optimizer, lr in ((self.optimizer_G, G_lr),
                              (self.optimizer_D, D_lr)):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if not isinstance(name, str):
                continue
            if hasattr(self, name):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def save_networks(self, epoch):
        """Save the network weights and every optimizer state for ``epoch`` into ``save_dir``."""
        self.modules_on_one_gpu.save_networks(epoch, self.save_dir)
        for i, optimizer in enumerate(self.optimizers):
            save_path = os.path.join(self.save_dir,
                                     '%s_optim-%d.pth' % (epoch, i))
            torch.save(optimizer.state_dict(), save_path)
class BaseResnetDistiller(BaseModel):
    """Common scaffolding for distilling a teacher generator into a student.

    The constructor builds a teacher generator (put in eval mode), a student
    generator, a discriminator, one 1x1 "adaptor" convolution per mapped
    layer, the two Adam optimizers and the evaluation helpers (Inception
    network, optional DRN segmentation network, real-image statistics).

    Subclasses must implement ``forward``, ``calc_distill_loss``,
    ``backward_G``, ``optimize_parameters`` and ``evaluate_model``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the distiller-specific command line options.

        Parameters:
            parser   -- the option parser to extend
            is_train -- must be truthy; distillers are training-only

        Returns:
            the extended parser
        """
        assert is_train
        parser = super(BaseResnetDistiller, BaseResnetDistiller).modify_commandline_options(parser, is_train)
        parser.add_argument('--teacher_netG', type=str, default='mobile_resnet_9blocks',
                            help='specify teacher generator architecture',
                            choices=['resnet_9blocks', 'mobile_resnet_9blocks',
                                     'super_mobile_resnet_9blocks', 'sub_mobile_resnet_9blocks'])
        parser.add_argument('--student_netG', type=str, default='mobile_resnet_9blocks',
                            help='specify student generator architecture',
                            choices=['resnet_9blocks', 'mobile_resnet_9blocks',
                                     'super_mobile_resnet_9blocks', 'sub_mobile_resnet_9blocks'])
        parser.add_argument('--teacher_ngf', type=int, default=64,
                            help='the base number of filters of the teacher generator')
        parser.add_argument('--student_ngf', type=int, default=48,
                            help='the base number of filters of the student generator')
        parser.add_argument('--restore_teacher_G_path', type=str, required=True,
                            help='the path to restore the teacher generator')
        parser.add_argument('--restore_student_G_path', type=str, default=None,
                            help='the path to restore the student generator')
        parser.add_argument('--restore_A_path', type=str, default=None,
                            help='the path to restore the adaptors for distillation')
        parser.add_argument('--restore_D_path', type=str, default=None,
                            help='the path to restore the discriminator')
        parser.add_argument('--restore_O_path', type=str, default=None,
                            help='the path to restore the optimizer')
        parser.add_argument('--recon_loss_type', type=str, default='l1',
                            choices=['l1', 'l2', 'smooth_l1', 'vgg'],
                            help='the type of the reconstruction loss')
        parser.add_argument('--lambda_distill', type=float, default=1,
                            help='weights for the intermediate activation distillation loss')
        parser.add_argument('--lambda_recon', type=float, default=100,
                            help='weights for the reconstruction loss.')
        parser.add_argument('--lambda_gan', type=float, default=1,
                            help='weight for gan loss')
        parser.add_argument('--teacher_dropout_rate', type=float, default=0)
        parser.add_argument('--student_dropout_rate', type=float, default=0)
        return parser

    def __init__(self, opt):
        """Build networks, losses, adaptors, optimizers and evaluation helpers.

        Parameters:
            opt -- the experiment options; ``opt.isTrain`` must be truthy

        Raises:
            NotImplementedError -- for an unknown ``opt.dataset_mode`` or
                                   ``opt.recon_loss_type``
        """
        assert opt.isTrain
        super(BaseResnetDistiller, self).__init__(opt)
        self.loss_names = ['G_gan', 'G_distill', 'G_recon', 'D_fake', 'D_real']
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']

        # Generators.
        self.netG_teacher = networks.define_G(opt.input_nc, opt.output_nc, opt.teacher_ngf,
                                              opt.teacher_netG, opt.norm, opt.teacher_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        self.netG_student = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                              opt.student_netG, opt.norm, opt.student_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:
            # Second student instance built from the architecture name with
            # the 'super_' prefix removed; load_networks fills it from the
            # same checkpoint as the student.
            self.netG_student_tmp = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                                      opt.student_netG.replace('super_', ''), opt.norm,
                                                      opt.student_dropout_rate, opt.init_type,
                                                      opt.init_gain, self.gpu_ids, opt=opt)
        if hasattr(opt, 'distiller'):
            self.netG_pretrained = networks.define_G(opt.input_nc, opt.output_nc, opt.pretrained_ngf,
                                                     opt.pretrained_netG, opt.norm, 0,
                                                     opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        # Discriminator: for 'aligned' data it sees input and output
        # concatenated along the channel axis (see backward_D).
        if opt.dataset_mode == 'aligned':
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
        elif opt.dataset_mode == 'unaligned':
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
        else:
            raise NotImplementedError('Unknown dataset mode [%s]!!!' % opt.dataset_mode)

        self.netG_teacher.eval()

        # Losses.
        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        elif opt.recon_loss_type == 'vgg':
            self.criterionRecon = models.modules.loss.VGGLoss(self.device)
        else:
            # Report the option that was actually inspected; `opt.loss_type`
            # is not an option registered by this class.
            raise NotImplementedError('Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)

        # Names of the generator sub-modules whose outputs are captured for
        # distillation; a DataParallel wrapper adds a 'module.' prefix.
        if isinstance(self.netG_teacher, nn.DataParallel):
            self.mapping_layers = ['module.model.%d' % i for i in range(9, 21, 3)]
        else:
            self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]

        # One 1x1 adaptor per mapped layer, taking 4 * student_ngf channels
        # to 4 * teacher_ngf channels. The adaptors are optimized together
        # with the student generator.
        self.netAs = []
        self.Tacts, self.Sacts = {}, {}
        G_params = [self.netG_student.parameters()]
        ft, fs = self.opt.teacher_ngf, self.opt.student_ngf
        use_plain_conv = hasattr(opt, 'distiller')
        for i, _ in enumerate(self.mapping_layers):
            if use_plain_conv:
                netA = nn.Conv2d(in_channels=fs * 4, out_channels=ft * 4,
                                 kernel_size=1).to(self.device)
            else:
                netA = SuperConv2d(in_channels=fs * 4, out_channels=ft * 4,
                                   kernel_size=1).to(self.device)
            networks.init_net(netA)
            G_params.append(netA.parameters())
            self.netAs.append(netA)
            self.loss_names.append('G_distill%d' % i)

        # Optimizers.
        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params), lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        # Evaluation helpers.
        self.eval_dataloader = create_eval_dataloader(self.opt, direction=opt.direction)

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        self.npz = np.load(opt.real_stat_path)
        self.is_best = False

    def setup(self, opt, verbose=True):
        """Create schedulers, load checkpoints and install activation hooks.

        Forward hooks are only registered when ``lambda_distill`` is positive;
        they store each mapped layer's output in ``self.Tacts`` (teacher) and
        ``self.Sacts`` (student), keyed by module name.
        """
        self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        self.load_networks(verbose)
        if verbose:
            self.print_networks()
        if self.opt.lambda_distill > 0:
            def get_activation(mem, name):
                def get_output_hook(module, input, output):
                    mem[name] = output

                return get_output_hook

            def add_hook(net, mem, mapping_layers):
                for n, m in net.named_modules():
                    if n in mapping_layers:
                        m.register_forward_hook(get_activation(mem, n))

            add_hook(self.netG_teacher, self.Tacts, self.mapping_layers)
            add_hook(self.netG_student, self.Sacts, self.mapping_layers)

    def set_input(self, input):
        """Unpack a paired batch, swapping A and B unless direction is 'AtoB'."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def set_single_input(self, input):
        """Unpack a batch that only carries the 'A' side."""
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Run the generators; must be provided by the subclass."""
        raise NotImplementedError

    def backward_D(self):
        """Compute the discriminator loss on the student output and backprop.

        For 'aligned' data the discriminator is conditioned on ``real_A``
        through channel-wise concatenation. Inputs are detached, so no
        gradient reaches the generators.
        """
        if self.opt.dataset_mode == 'aligned':
            fake = torch.cat((self.real_A, self.Sfake_B), 1).detach()
            real = torch.cat((self.real_A, self.real_B), 1).detach()
        else:
            fake = self.Sfake_B.detach()
            real = self.real_B.detach()

        pred_fake = self.netD(fake)
        self.loss_D_fake = self.criterionGAN(pred_fake, False, for_discriminator=True)

        pred_real = self.netD(real)
        self.loss_D_real = self.criterionGAN(pred_real, True, for_discriminator=True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def calc_distill_loss(self):
        """Compute the distillation loss; must be provided by the subclass."""
        raise NotImplementedError

    def backward_G(self):
        """Compute generator losses and backprop; must be provided by the subclass."""
        raise NotImplementedError

    def optimize_parameters(self, steps):
        """Run one optimization step; must be provided by the subclass."""
        raise NotImplementedError

    def print_networks(self):
        """Print every available network and write its summary to ``opt.log_dir``."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if not hasattr(self, name):
                continue
            net = getattr(self, name)
            num_params = sum(param.numel() for param in net.parameters())
            summary = '[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)
            print(net)
            print(summary)
            with open(os.path.join(self.opt.log_dir, name + '.txt'), 'w') as f:
                f.write(str(net) + '\n')
                f.write(summary + '\n')
        print('-----------------------------------------------')

    def load_networks(self, verbose=True):
        """Restore the teacher and every optional checkpoint that was given.

        After an optimizer state is restored its learning rate is reset to
        ``opt.lr``.
        """
        util.load_network(self.netG_teacher, self.opt.restore_teacher_G_path, verbose)
        if self.opt.restore_student_G_path is not None:
            util.load_network(self.netG_student, self.opt.restore_student_G_path, verbose)
            # netG_student_tmp only exists when a student checkpoint was given.
            if hasattr(self, 'netG_student_tmp'):
                util.load_network(self.netG_student_tmp, self.opt.restore_student_G_path, verbose)
        if self.opt.restore_D_path is not None:
            util.load_network(self.netD, self.opt.restore_D_path, verbose)
        if self.opt.restore_A_path is not None:
            for i, netA in enumerate(self.netAs):
                path = '%s-%d.pth' % (self.opt.restore_A_path, i)
                util.load_network(netA, path, verbose)
        if self.opt.restore_O_path is not None:
            for i, optimizer in enumerate(self.optimizers):
                path = '%s-%d.pth' % (self.opt.restore_O_path, i)
                util.load_optimizer(optimizer, path, verbose)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.opt.lr

    def save_networks(self, epoch):
        """Save the student generator, discriminator, adaptors and optimizers.

        The teacher is not saved. Files are written to ``self.save_dir`` and
        prefixed with ``epoch``.
        """

        def save_net(net, save_path):
            # Weights are saved from CPU; when GPUs are in use the network is
            # moved back to the first configured GPU afterwards.
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                if isinstance(net, DataParallel):
                    torch.save(net.module.cpu().state_dict(), save_path)
                else:
                    torch.save(net.cpu().state_dict(), save_path)
                net.cuda(self.gpu_ids[0])
            else:
                torch.save(net.cpu().state_dict(), save_path)

        save_filename = '%s_net_%s.pth' % (epoch, 'G')
        save_path = os.path.join(self.save_dir, save_filename)
        save_net(self.netG_student, save_path)

        save_filename = '%s_net_%s.pth' % (epoch, 'D')
        save_path = os.path.join(self.save_dir, save_filename)
        save_net(self.netD, save_path)

        for i, net in enumerate(self.netAs):
            save_filename = '%s_net_%s-%d.pth' % (epoch, 'A', i)
            save_path = os.path.join(self.save_dir, save_filename)
            save_net(net, save_path)

        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            save_path = os.path.join(self.save_dir, save_filename)
            torch.save(optimizer.state_dict(), save_path)

    def evaluate_model(self, step):
        """Evaluate the student; must be provided by the subclass."""
        raise NotImplementedError

    def test(self):
        """Run ``forward`` with gradient tracking disabled."""
        with torch.no_grad():
            self.forward()
class BaseSPADEDistiller(SPADEModel):
    """Common scaffolding for distilling a SPADE teacher into a student.

    The teacher, student and discriminator live in a single
    ``SPADEDistillerModules`` container (optionally wrapped for multi-GPU
    use); ``self.modules_on_one_gpu`` always refers to the unwrapped
    container. Subclasses must implement ``evaluate_model``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the SPADE-distiller command line options and defaults.

        Parameters:
            parser   -- an ``argparse.ArgumentParser`` to extend
            is_train -- accepted for interface compatibility; not read here

        Returns:
            the extended parser
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='more',
                            help="If 'more', adds upsampling layer between the two middle resnet blocks. "
                                 "If 'most', also add one more upsampling + resnet layer at the end of the generator")
        parser.add_argument('--teacher_netG', type=str, default='inception_spade',
                            help='specify teacher generator architecture',
                            choices=['inception_spade'])
        parser.add_argument('--student_netG', type=str, default='inception_spade',
                            help='specify student generator architecture',
                            choices=['inception_spade'])
        parser.add_argument('--teacher_ngf', type=int, default=64,
                            help='the base number of filters of the teacher generator')
        parser.add_argument('--student_ngf', type=int, default=48,
                            help='the base number of filters of the student generator')
        parser.add_argument('--teacher_norm_G', type=str, default='spadesyncbatch3x3',
                            help='instance normalization or batch normalization of the teacher model')
        parser.add_argument('--student_norm_G', type=str, default='spadesyncbatch3x3',
                            help='instance normalization or batch normalization of the student model')
        parser.add_argument('--restore_teacher_G_path', type=str, required=True,
                            help='the path to restore the teacher generator')
        parser.add_argument('--restore_student_G_path', type=str, default=None,
                            help='the path to restore the student generator')
        parser.add_argument('--restore_A_path', type=str, default=None,
                            help='the path to restore the adaptors for distillation')
        parser.add_argument('--restore_D_path', type=str, default=None,
                            help='the path to restore the discriminator')
        parser.add_argument('--restore_O_path', type=str, default=None,
                            help='the path to restore the optimizer')
        parser.add_argument('--lambda_gan', type=float, default=1,
                            help='weight for gan loss')
        parser.add_argument('--lambda_feat', type=float, default=10,
                            help='weight for gan feature loss')
        parser.add_argument('--lambda_vgg', type=float, default=10,
                            help='weight for vgg loss')
        # Help text previously duplicated the vgg description.
        parser.add_argument('--lambda_distill', type=float, default=10,
                            help='weight for distillation loss')
        parser.add_argument('--distill_G_loss_type', type=str, default='mse',
                            choices=['mse', 'ka'],
                            help='the type of the G distillation loss')
        parser.add_argument('--beta2', type=float, default=0.999,
                            help='momentum term of adam')
        # The flag switches TTUR off, so the help text says so.
        parser.add_argument('--no_TTUR', action='store_true',
                            help='Do not use TTUR training scheme')
        parser.add_argument('--no_fid', action='store_true',
                            help='No FID evaluation during training')
        parser.add_argument('--no_mIoU', action='store_true',
                            help='No mIoU evaluation during training '
                                 '(sometimes because there are CUDA memory)')
        parser.set_defaults(netD='multi_scale', ndf=64, dataset_mode='cityscapes',
                            batch_size=16, print_freq=50, save_latest_freq=10000000000,
                            save_epoch_freq=10, nepochs=100, nepochs_decay=100,
                            init_type='xavier')
        return parser

    def __init__(self, opt):
        """Build the module container, optimizers and evaluation helpers.

        Parameters:
            opt -- the experiment options
        """
        # Starts the MRO lookup *after* SPADEModel, so SPADEModel.__init__
        # itself is not run.
        # NOTE(review): presumably to avoid building SPADEModel's own
        # networks -- confirm against SPADEModel.
        super(SPADEModel, self).__init__(opt)
        # 'D' used to be appended a second time right after this list was
        # built, which listed the discriminator twice.
        self.model_names = ['G_student', 'G_teacher', 'D']
        self.visual_names = ['labels', 'Tfake_B', 'Sfake_B', 'real_B']
        self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'G_distill', 'D_real', 'D_fake']

        # NOTE(review): self.modules is only created when opt has a
        # `distiller` attribute, yet the statements below require it --
        # confirm how other configurations are expected to provide it.
        if hasattr(opt, 'distiller'):
            self.modules = SPADEDistillerModules(opt).to(self.device)
        if len(opt.gpu_ids) > 0:
            self.modules = DataParallelWithCallback(self.modules, device_ids=opt.gpu_ids)
            self.modules_on_one_gpu = self.modules.module
        else:
            self.modules_on_one_gpu = self.modules

        # One extra loss entry per distilled layer.
        for i in range(len(self.modules_on_one_gpu.mapping_layers)):
            self.loss_names.append('G_distill%d' % i)

        self.optimizer_G, self.optimizer_D = self.modules_on_one_gpu.create_optimizers()
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        # Evaluation networks are only built for the metrics that are enabled.
        if not opt.no_fid:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_model = InceptionV3([block_idx])
            self.inception_model.to(self.device)
            self.inception_model.eval()
        if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            self.drn_model.to(self.device)
            self.drn_model.eval()

        self.eval_dataloader = create_eval_dataloader(self.opt)
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.npz = np.load(opt.real_stat_path)

        # Profile both generators so that their `n_macs` can be reported.
        teacher = self.modules_on_one_gpu.netG_teacher
        student = self.modules_on_one_gpu.netG_student
        for net in (teacher, student):
            model_profiling(net, self.opt.data_height, self.opt.data_width,
                            channel=self.opt.data_channel, num_forwards=0, verbose=False)
        print(f'Teacher FLOPs: {teacher.n_macs}, Student FLOPs: {student.n_macs}.')

    def forward(self, on_one_gpu=False):
        """Produce the teacher and student outputs for the current input.

        Parameters:
            on_one_gpu -- when true, call the unwrapped module container
                          instead of ``self.modules``
        """
        runner = self.modules_on_one_gpu if on_one_gpu else self.modules
        self.Tfake_B, self.Sfake_B = runner(self.input_semantics)

    def load_networks(self, verbose=True, teacher_only=False, restore_pretrain=True):
        """Restore network weights and, if requested, optimizer states.

        After the optimizer states are restored the learning rates are reset:
        both to ``opt.lr`` when ``no_TTUR`` is set, otherwise ``opt.lr / 2``
        for the generator and ``opt.lr * 2`` for the discriminator. This
        override used to reset both optimizers to ``opt.lr`` regardless of
        ``no_TTUR``.
        """
        self.modules_on_one_gpu.load_networks(verbose, teacher_only=teacher_only,
                                              restore_pretrain=restore_pretrain)
        if self.opt.restore_O_path is None:
            return
        for i, optimizer in enumerate(self.optimizers):
            path = '%s-%d.pth' % (self.opt.restore_O_path, i)
            util.load_optimizer(optimizer, path, verbose)
        if self.opt.no_TTUR:
            G_lr, D_lr = self.opt.lr, self.opt.lr
        else:
            G_lr, D_lr = self.opt.lr / 2, self.opt.lr * 2
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = G_lr
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = D_lr

    def save_networks(self, epoch):
        """Save the networks and both optimizer states under ``self.save_dir``."""
        self.modules_on_one_gpu.save_networks(epoch, self.save_dir)
        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            save_path = os.path.join(self.save_dir, save_filename)
            torch.save(optimizer.state_dict(), save_path)

    def evaluate_model(self, step):
        """Evaluate the student; must be provided by the subclass."""
        raise NotImplementedError

    def optimize_parameters(self, steps):
        """Run one generator update followed by one discriminator update.

        Discriminator gradients are disabled during the generator step.
        ``steps`` is accepted for interface compatibility and not read here.
        """
        self.set_requires_grad(self.modules_on_one_gpu.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.set_requires_grad(self.modules_on_one_gpu.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
def main(configs, opt, gpu_id, queue, verbose):
    """Evaluate a list of sub-network configurations on one GPU.

    For every configuration the MACs are profiled. Configurations within
    ``opt.budget`` are run over the whole dataloader and scored with FID
    and/or mIoU; configurations over budget are not run and receive sentinel
    scores instead (``1e9`` for FID, ``-1e9`` for mIoU). Each result dict is
    printed, and the full list is put on ``queue`` at the end.

    Parameters:
        configs -- iterable of configurations accepted by ``model.profile``
                   and ``model.test``
        opt     -- the experiment options; ``opt.gpu_ids`` is overwritten
                   with ``[gpu_id]``
        gpu_id  -- the id of the GPU this worker uses
        queue   -- object with a ``put`` method that receives the results
        verbose -- forwarded to the dataloader/model factories and ``setup``
    """
    opt.gpu_ids = [gpu_id]
    dataloader = create_dataloader(opt, verbose)
    model = create_model(opt, verbose)
    model.setup(opt, verbose)
    device = model.device

    # Which metrics to compute does not change between configurations.
    use_fid = not opt.no_fid
    use_mIoU = 'cityscapes' in opt.dataroot and opt.direction == 'BtoA'

    if use_fid:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        inception_model = InceptionV3([block_idx])
        inception_model.to(device)
        inception_model.eval()
    if use_mIoU:
        drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(drn_model, opt.drn_path, verbose=False)
        if len(opt.gpu_ids) > 0:
            drn_model = nn.DataParallel(drn_model, opt.gpu_ids)
        drn_model.eval()

    npz = np.load(opt.real_stat_path)
    results = []

    # Give the model one batch before profiling.
    # NOTE(review): presumably model.profile needs an input to be set --
    # confirm against the model implementation.
    for data_i in dataloader:
        model.set_input(data_i)
        break

    for config in tqdm.tqdm(configs):
        macs, _ = model.profile(config)
        qualified = not (macs > opt.budget)

        fakes, names = [], []
        if qualified:
            for data_i in dataloader:
                model.set_input(data_i)
                model.test(config)
                fakes.append(model.fake_B.cpu())
                for path in model.get_image_paths():
                    short_path = ntpath.basename(path)
                    names.append(os.path.splitext(short_path)[0])

        result = {'config_str': encode_config(config), 'macs': macs}
        if use_fid:
            if qualified:
                result['fid'] = get_fid(fakes, inception_model, npz, device,
                                        opt.batch_size, use_tqdm=False)
            else:
                result['fid'] = 1e9
        if use_mIoU:
            if qualified:
                result['mIoU'] = get_mIoU(fakes, names, drn_model, device,
                                          data_dir=opt.cityscapes_path,
                                          batch_size=opt.batch_size,
                                          num_workers=opt.num_threads,
                                          use_tqdm=False)
            else:
                # This branch used to reuse the previous configuration's
                # score, or raise UnboundLocalError on the first one. Record
                # a worst-case sentinel instead, mirroring the FID branch.
                result['mIoU'] = -1e9
        print(result, flush=True)
        results.append(result)
    queue.put(results)