def __init__(self, opt):
    """Set up the SPADE generator wrapper and, in training mode, its evaluation state.

    Args:
        opt: parsed options. Reads ``gpu_ids``, ``isTrain``, ``no_fid``,
            ``no_mIoU``, ``dataroot``, ``drn_path`` and ``real_stat_path``.
    """
    super(SPADEModel, self).__init__(opt)
    self.model_names = ['G']
    self.visual_names = ['labels', 'fake_B', 'real_B']

    # Keep a handle on the unwrapped module so single-GPU calls can skip the
    # data-parallel wrapper.
    bare_modules = SPADEModelModules(opt).to(self.device)
    if len(opt.gpu_ids) > 0:
        self.modules = DataParallelWithCallback(bare_modules,
                                                device_ids=opt.gpu_ids)
        self.modules_on_one_gpu = self.modules.module
    else:
        self.modules = bare_modules
        self.modules_on_one_gpu = bare_modules

    if not opt.isTrain:
        self.modules.eval()
    else:
        self.model_names.append('D')
        self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
        optimizer_pair = self.modules_on_one_gpu.create_optimizers()
        self.optimizer_G, self.optimizer_D = optimizer_pair
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        # Feature extractor used for FID during training-time evaluation.
        if not opt.no_fid:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_model = InceptionV3([block_idx])
            self.inception_model.to(self.device)
            self.inception_model.eval()

        # Segmentation network used for the Cityscapes mIoU metric.
        if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            self.drn_model.to(self.device)
            self.drn_model.eval()

        self.eval_dataloader = create_eval_dataloader(self.opt)
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.npz = np.load(opt.real_stat_path)

    self.train_dataloader = create_train_dataloader(opt)
def __init__(self, opt):
    """Set up the teacher/student SPADE distiller (or supernet) for training.

    Args:
        opt: parsed options. Must have ``isTrain`` set and both
            ``teacher_netG`` and ``student_netG`` among the supported
            generator names. The presence of a ``distiller`` attribute
            selects ``SPADEDistillerModules``; otherwise
            ``SPADESupernetModules`` is used.

    Raises:
        AssertionError: if ``opt.isTrain`` is false or a generator name is
            not supported.
    """
    assert opt.isTrain
    valid_netGs = [
        'spade', 'mobile_spade', 'super_mobile_spade', 'sub_mobile_spade'
    ]
    assert opt.teacher_netG in valid_netGs and opt.student_netG in valid_netGs
    super(SPADEModel, self).__init__(opt)

    # 'D' is listed exactly once; the original appended it a second time,
    # which made every per-name loop handle the discriminator twice.
    self.model_names = ['G_student', 'G_teacher', 'D']
    self.visual_names = ['labels', 'Tfake_B', 'Sfake_B', 'real_B']
    self.loss_names = [
        'G_gan', 'G_feat', 'G_vgg', 'G_distill', 'D_real', 'D_fake'
    ]

    # Both variants are wrapped identically, so pick the class first and
    # share the wrapping logic.
    if hasattr(opt, 'distiller'):
        modules_cls = SPADEDistillerModules
    else:
        modules_cls = SPADESupernetModules
    self.modules = modules_cls(opt).to(self.device)
    if len(opt.gpu_ids) > 0:
        self.modules = DataParallelWithCallback(self.modules,
                                                device_ids=opt.gpu_ids)
        self.modules_on_one_gpu = self.modules.module
    else:
        self.modules_on_one_gpu = self.modules

    # One extra distillation loss entry per mapping layer.
    num_mappings = len(self.modules_on_one_gpu.mapping_layers)
    self.loss_names.extend('G_distill%d' % i for i in range(num_mappings))

    self.optimizer_G, self.optimizer_D = \
        self.modules_on_one_gpu.create_optimizers()
    self.optimizers = [self.optimizer_G, self.optimizer_D]

    # Feature extractor used for FID evaluation.
    if not opt.no_fid:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

    # Segmentation network used for the Cityscapes mIoU metric.
    if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
        self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(self.drn_model, opt.drn_path, verbose=False)
        self.drn_model.to(self.device)
        self.drn_model.eval()

    self.eval_dataloader = create_eval_dataloader(self.opt)
    self.best_fid = 1e9
    self.best_mIoU = -1e9
    self.fids, self.mIoUs = [], []
    self.is_best = False
    self.npz = np.load(opt.real_stat_path)
def __init__(self, opt):
    """Set up the teacher/student SPADE distiller and profile both generators.

    Args:
        opt: parsed options. Reads ``gpu_ids``, ``no_fid``, ``no_mIoU``,
            ``dataroot``, ``drn_path``, ``real_stat_path`` and the
            ``data_height`` / ``data_width`` / ``data_channel`` values
            passed to the profiler. The modules are only built when ``opt``
            has a ``distiller`` attribute.
    """
    super(SPADEModel, self).__init__(opt)

    # 'D' is listed exactly once; the original appended it a second time,
    # which made every per-name loop handle the discriminator twice.
    self.model_names = ['G_student', 'G_teacher', 'D']
    self.visual_names = ['labels', 'Tfake_B', 'Sfake_B', 'real_B']
    self.loss_names = [
        'G_gan', 'G_feat', 'G_vgg', 'G_distill', 'D_real', 'D_fake'
    ]

    if hasattr(opt, 'distiller'):
        self.modules = SPADEDistillerModules(opt).to(self.device)
        if len(opt.gpu_ids) > 0:
            self.modules = DataParallelWithCallback(self.modules,
                                                    device_ids=opt.gpu_ids)
            self.modules_on_one_gpu = self.modules.module
        else:
            self.modules_on_one_gpu = self.modules

    # One extra distillation loss entry per mapping layer.
    num_mappings = len(self.modules_on_one_gpu.mapping_layers)
    self.loss_names.extend('G_distill%d' % i for i in range(num_mappings))

    self.optimizer_G, self.optimizer_D = \
        self.modules_on_one_gpu.create_optimizers()
    self.optimizers = [self.optimizer_G, self.optimizer_D]

    # Feature extractor used for FID evaluation.
    if not opt.no_fid:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

    # Segmentation network used for the Cityscapes mIoU metric.
    if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
        self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(self.drn_model, opt.drn_path, verbose=False)
        self.drn_model.to(self.device)
        self.drn_model.eval()

    self.eval_dataloader = create_eval_dataloader(self.opt)
    self.best_fid = 1e9
    self.best_mIoU = -1e9
    self.fids, self.mIoUs = [], []
    self.is_best = False
    self.npz = np.load(opt.real_stat_path)

    # Profile teacher first, then student; the printed values are read from
    # the ``n_macs`` attribute of each generator afterwards.
    teacher = self.modules_on_one_gpu.netG_teacher
    student = self.modules_on_one_gpu.netG_student
    for generator in (teacher, student):
        model_profiling(generator,
                        self.opt.data_height,
                        self.opt.data_width,
                        channel=self.opt.data_channel,
                        num_forwards=0,
                        verbose=False)
    print(
        f'Teacher FLOPs: {teacher.n_macs}, Student FLOPs: {student.n_macs}.'
    )
class SPADEModel(BaseModel):
    """SPADE model wrapper: option setup, training steps, evaluation and checkpoint I/O."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register SPADE-specific options and defaults on ``parser``.

        Args:
            parser: an ``argparse.ArgumentParser`` to extend.
            is_train: when true, the training-only options and defaults are
                registered as well.

        Returns:
            The parser after ``networks.modify_commandline_options`` has
            also been applied to it.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.set_defaults(netG='sub_mobile_spade')
        parser.add_argument('--separable_conv_norm', type=str, default='instance',
                            choices=('none', 'instance', 'batch'),
                            help='whether to use instance norm for the separable convolutions')
        parser.add_argument('--norm_G', type=str, default='spadesyncbatch3x3',
                            help='instance normalization or batch normalization')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='more',
                            help="If 'more', adds upsampling layer between the two middle resnet blocks. "
                                 "If 'most', also add one more upsampling + resnet layer at the end of the generator")
        if is_train:
            parser.add_argument('--restore_G_path', type=str, default=None,
                                help='the path to restore the generator')
            parser.add_argument('--restore_D_path', type=str, default=None,
                                help='the path to restore the discriminator')
            parser.add_argument('--real_stat_path', type=str, required=True,
                                help='the path to load the groud-truth images information to compute FID.')
            parser.add_argument('--lambda_gan', type=float, default=1, help='weight for gan loss')
            parser.add_argument('--lambda_feat', type=float, default=10, help='weight for gan feature loss')
            parser.add_argument('--lambda_vgg', type=float, default=10, help='weight for vgg loss')
            parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
            # The flag turns TTUR off (see load_networks); the previous help
            # text claimed the opposite.
            parser.add_argument('--no_TTUR', action='store_true',
                                help='Disable the TTUR training scheme '
                                     '(use the same learning rate for G and D)')
            parser.add_argument('--no_fid', action='store_true', help='No FID evaluation during training')
            parser.add_argument('--no_mIoU', action='store_true', help='No mIoU evaluation during training '
                                                                       '(sometimes because there are CUDA memory)')
            parser.set_defaults(netD='multi_scale', ndf=64, dataset_mode='cityscapes', batch_size=16,
                                print_freq=50, save_latest_freq=10000000000, save_epoch_freq=10,
                                nepochs=100, nepochs_decay=100, init_type='xavier')
        parser = networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        """Build the generator wrapper and, in training mode, optimizers and metric models.

        Args:
            opt: parsed options. Reads ``gpu_ids``, ``isTrain``, ``no_fid``,
                ``no_mIoU``, ``dataroot``, ``drn_path`` and
                ``real_stat_path``.
        """
        super(SPADEModel, self).__init__(opt)
        self.model_names = ['G']
        self.visual_names = ['labels', 'fake_B', 'real_B']
        self.modules = SPADEModelModules(opt).to(self.device)
        if len(opt.gpu_ids) > 0:
            self.modules = DataParallelWithCallback(self.modules, device_ids=opt.gpu_ids)
            # Unwrapped handle for calls that must bypass the data-parallel wrapper.
            self.modules_on_one_gpu = self.modules.module
        else:
            self.modules_on_one_gpu = self.modules
        if opt.isTrain:
            self.model_names.append('D')
            self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
            self.optimizer_G, self.optimizer_D = self.modules_on_one_gpu.create_optimizers()
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            if not opt.no_fid:
                # Feature extractor used for FID evaluation.
                block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
                self.inception_model = InceptionV3([block_idx])
                self.inception_model.to(self.device)
                self.inception_model.eval()
            if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
                # Segmentation network used for the Cityscapes mIoU metric.
                self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
                util.load_network(self.drn_model, opt.drn_path, verbose=False)
                self.drn_model.to(self.device)
                self.drn_model.eval()
            self.eval_dataloader = create_eval_dataloader(self.opt)
            self.best_fid = 1e9
            self.best_mIoU = -1e9
            self.fids, self.mIoUs = [], []
            self.is_best = False
            self.npz = np.load(opt.real_stat_path)
        else:
            self.modules.eval()
        # calibrate() iterates this loader.
        self.train_dataloader = create_train_dataloader(opt)

    def set_input(self, input):
        """Store a data batch and derive the generator input from it.

        Args:
            input: mapping with at least ``'path'`` and ``'label'`` plus the
                keys read by ``preprocess_input``.
        """
        self.data = input
        self.image_paths = input['path']
        self.labels = input['label'].to(self.device)
        self.input_semantics, self.real_B = self.preprocess_input(input)

    def test(self, config=None):
        """Run a gradient-free forward pass on the unwrapped modules."""
        with torch.no_grad():
            self.forward(on_one_gpu=True, config=config)

    def preprocess_input(self, data):
        """Move a batch to the model device and build the one-hot semantic input.

        ``data`` is modified in place: its ``'label'``, ``'instance'`` and
        ``'image'`` entries are replaced by their on-device versions.

        Returns:
            ``(input_semantics, image)`` where ``input_semantics`` is the
            one-hot label map, with the instance edge map concatenated on
            dim 1 unless ``opt.no_instance`` is set.
        """
        # move to GPU and change data types
        data['label'] = data['label'].long()
        data['label'] = data['label'].to(self.device)
        data['instance'] = data['instance'].to(self.device)
        data['image'] = data['image'].to(self.device)

        # create one-hot label map
        label_map = data['label']
        bs, _, h, w = label_map.size()
        nc = self.opt.input_nc + 1 if self.opt.contain_dontcare_label \
            else self.opt.input_nc
        input_label = torch.zeros([bs, nc, h, w], device=self.device)
        input_semantics = input_label.scatter_(1, label_map, 1.0)

        # concatenate instance map if it exists
        if not self.opt.no_instance:
            inst_map = data['instance']
            instance_edge_map = self.get_edges(inst_map)
            input_semantics = torch.cat((input_semantics, instance_edge_map), dim=1)

        return input_semantics, data['image']

    def forward(self, on_one_gpu=False, config=None):
        """Generate ``self.fake_B`` from ``self.input_semantics``.

        Args:
            on_one_gpu: call the unwrapped modules instead of ``self.modules``.
            config: if given, assigned to the unwrapped modules' ``config``
                before the call.
        """
        if config is not None:
            self.modules_on_one_gpu.config = config
        if on_one_gpu:
            self.fake_B = self.modules_on_one_gpu(self.input_semantics)
        else:
            self.fake_B = self.modules(self.input_semantics)

    def get_edges(self, t):
        """Return a float mask marking positions whose neighbour along the last two dims differs.

        Args:
            t: 4-D tensor (indexed with four subscripts below).

        Returns:
            Float tensor of the same size as ``t`` with 1.0 where the value
            differs from an adjacent element along dim 2 or dim 3.
        """
        edge = torch.zeros(t.size(), dtype=torch.uint8, device=self.device)
        # Each difference mask is needed twice (for both sides of the
        # boundary), so compute it once.
        diff_last = (t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
        diff_rows = (t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | diff_last
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | diff_last
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | diff_rows
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | diff_rows
        return edge.float()

    def profile(self, config=None, verbose=True):
        """Profile the unwrapped modules on the first sample of the current input.

        Returns:
            ``(macs, params)`` as returned by the modules' ``profile``.
        """
        if config is not None:
            self.modules_on_one_gpu.config = config
        macs, params = self.modules_on_one_gpu.profile(self.input_semantics[:1])
        if verbose:
            print('MACs: %.3fG\tParams: %.3fM' % (macs / 1e9, params / 1e6), flush=True)
        return macs, params

    def backward_G(self):
        """Compute generator losses, record the 'G*' entries and backpropagate."""
        losses = self.modules(self.input_semantics, self.real_B, mode='G_loss')
        loss_G = losses['loss_G'].mean()
        for loss_name in self.loss_names:
            if loss_name.startswith('G'):
                setattr(self, 'loss_%s' % loss_name, losses[loss_name].detach().mean())
        loss_G.backward()

    def backward_D(self):
        """Compute discriminator losses, record the 'D*' entries and backpropagate."""
        losses = self.modules(self.input_semantics, self.real_B, mode='D_loss')
        loss_D = losses['loss_D'].mean()
        for loss_name in self.loss_names:
            if loss_name.startswith('D'):
                setattr(self, 'loss_%s' % loss_name, losses[loss_name].detach().mean())
        loss_D.backward()

    def optimize_parameters(self, steps):
        """Run one generator update followed by one discriminator update.

        Args:
            steps: accepted for interface compatibility; not used here.
        """
        # Freeze D while updating G so its parameters receive no gradients.
        self.set_requires_grad(self.modules_on_one_gpu.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.set_requires_grad(self.modules_on_one_gpu.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    def evaluate_model(self, step):
        """Generate images for the eval set, save a few samples and compute metrics.

        Args:
            step: used as the sub-directory name under ``<log_dir>/eval``.

        Returns:
            Dict with ``metric/fid*`` and/or ``metric/mIoU*`` entries,
            depending on which metrics are enabled. ``self.is_best`` is set
            when either metric improves on its best value so far.
        """
        self.is_best = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.modules_on_one_gpu.netG.eval()
        torch.cuda.empty_cache()
        fakes, names = [], []
        ret = {}
        cnt = 0
        for data_i in tqdm(self.eval_dataloader, desc='Eval ', position=2, leave=False):
            self.set_input(data_i)
            self.test()
            fakes.append(self.fake_B.cpu())
            for j, image_path in enumerate(self.image_paths):
                short_path = ntpath.basename(image_path)
                name = os.path.splitext(short_path)[0]
                names.append(name)
                # Only the first 10 samples are written to disk.
                if cnt < 10:
                    input_im = util.tensor2label(self.input_semantics[j], self.opt.input_nc + 2)
                    real_im = util.tensor2im(self.real_B[j])
                    fake_im = util.tensor2im(self.fake_B[j])
                    util.save_image(input_im, os.path.join(save_dir, 'input', '%s.png' % name), create_dir=True)
                    util.save_image(real_im, os.path.join(save_dir, 'real', '%s.png' % name), create_dir=True)
                    util.save_image(fake_im, os.path.join(save_dir, 'fake', '%s.png' % name), create_dir=True)
                cnt += 1

        if not self.opt.no_fid:
            fid = get_fid(fakes, self.inception_model, self.npz, device=self.device,
                          batch_size=self.opt.eval_batch_size, tqdm_position=2)
            if fid < self.best_fid:
                self.is_best = True
                self.best_fid = fid
            # Rolling window of the 3 most recent values for the mean.
            self.fids.append(fid)
            if len(self.fids) > 3:
                self.fids.pop(0)
            ret['metric/fid'] = fid
            ret['metric/fid-mean'] = sum(self.fids) / len(self.fids)
            ret['metric/fid-best'] = self.best_fid
        if 'cityscapes' in self.opt.dataroot and not self.opt.no_mIoU:
            mIoU = get_cityscapes_mIoU(fakes, names, self.drn_model, self.device,
                                       table_path=self.opt.table_path,
                                       data_dir=self.opt.cityscapes_path,
                                       batch_size=self.opt.eval_batch_size,
                                       num_workers=self.opt.num_threads, tqdm_position=2)
            if mIoU > self.best_mIoU:
                self.is_best = True
                self.best_mIoU = mIoU
            # Same rolling-window handling as for FID above.
            self.mIoUs.append(mIoU)
            if len(self.mIoUs) > 3:
                self.mIoUs.pop(0)
            ret['metric/mIoU'] = mIoU
            ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs)
            ret['metric/mIoU-best'] = self.best_mIoU
        self.modules_on_one_gpu.netG.train()
        torch.cuda.empty_cache()
        return ret

    def print_networks(self):
        """Print each network and its parameter count; also write them under ``opt.log_dir`` if set."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self.modules_on_one_gpu, 'net' + name)
                num_params = sum(param.numel() for param in net.parameters())
                print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
                if hasattr(self.opt, 'log_dir'):
                    with open(os.path.join(self.opt.log_dir, 'net' + name + '.txt'), 'w') as f:
                        f.write(str(net) + '\n')
                        f.write('[Network %s] Total number of parameters : %.3f M\n' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def load_networks(self, verbose=True):
        """Load network weights and, when resuming training, optimizer states.

        When training with ``opt.restore_O_path`` set, optimizer ``i`` is
        loaded from ``'<restore_O_path>-<i>.pth'`` and the learning rates
        are then reset: ``opt.lr`` for both optimizers if ``opt.no_TTUR``,
        otherwise ``lr / 2`` for G and ``lr * 2`` for D.
        """
        self.modules_on_one_gpu.load_networks(verbose)
        # This class does not register --restore_O_path itself, so tolerate
        # option sets that lack it.
        restore_O_path = getattr(self.opt, 'restore_O_path', None)
        if self.isTrain and restore_O_path is not None:
            for i, optimizer in enumerate(self.optimizers):
                path = '%s-%d.pth' % (restore_O_path, i)
                util.load_optimizer(optimizer, path, verbose)
            if self.opt.no_TTUR:
                G_lr, D_lr = self.opt.lr, self.opt.lr
            else:
                G_lr, D_lr = self.opt.lr / 2, self.opt.lr * 2
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = G_lr
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = D_lr

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str) and hasattr(self, name):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def save_networks(self, epoch):
        """Save network weights and each optimizer state as ``'<epoch>_optim-<i>.pth'`` in ``self.save_dir``."""
        self.modules_on_one_gpu.save_networks(epoch, self.save_dir)
        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            save_path = os.path.join(self.save_dir, save_filename)
            torch.save(optimizer.state_dict(), save_path)

    def calibrate(self, config):
        """Run every training batch through the modules in 'calibrate' mode.

        The generator is put in train mode for the pass and in eval mode
        afterwards. ``config`` is deep-copied first, so the caller's object
        is not modified.
        """
        self.modules_on_one_gpu.netG.train()
        config = copy.deepcopy(config)
        for i, data in enumerate(self.train_dataloader):
            self.set_input(data)
            if i == 0:
                config['calibrate_bn'] = True
            self.modules_on_one_gpu.config = config
            self.modules(self.input_semantics, mode='calibrate')
        self.modules_on_one_gpu.netG.eval()