def main():
    """Sample random images from a trained generative model.

    Loads the config and checkpoint given on the command line, draws
    ``args.num_samples`` single-image batches from the model, and writes each
    result under ``args.samples_path``. With ``--save-prev-res`` the
    intermediate stage outputs are saved as well, one sub-directory per stage.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # honor the cudnn benchmark flag from the config
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # seed RNGs when requested for reproducible sampling
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model, switch to eval mode, then restore weights
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    model.eval()

    mmcv.print_log(f'Loading ckpt from {args.checkpoint}', 'mmgen')
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # single-GPU data-parallel wrapper
    model = MMDataParallel(model, device_ids=[0])

    pbar = mmcv.ProgressBar(args.num_samples)
    for sample_iter in range(args.num_samples):
        outputs = model(None, num_batches=1, get_prev_res=args.save_prev_res)
        if not args.save_prev_res:
            # keep only the final result
            mmcv.imwrite(
                _tensor2img(outputs),
                os.path.join(args.samples_path,
                             f'rand_sample_{sample_iter}.png'))
        else:
            # keep every intermediate stage output plus the final image;
            # the final image is appended so it becomes the last "stage"
            stage_imgs = outputs['prev_res_list']
            stage_imgs.append(outputs['fake_img'])
            for i, stage_tensor in enumerate(stage_imgs):
                mmcv.imwrite(
                    _tensor2img(stage_tensor),
                    os.path.join(args.samples_path, f'stage{i}',
                                 f'rand_sample_{sample_iter}.png'))
        pbar.update()

    # move to a fresh line after the progress bar finishes
    sys.stdout.write('\n')
def main():
    """Evaluate a generative model (single-GPU only) or just sample images.

    Builds the model from config/checkpoint, constructs the requested
    metrics, builds a dataloader when metrics need real data, then delegates
    to ``single_gpu_online_evaluation`` / ``single_gpu_evaluation``.
    Distributed launchers are initialized but multi-GPU evaluation itself is
    not implemented and raises ``NotImplementedError``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # remote checkpoints have no local directory to hold the eval log
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    # append mode: repeated evaluations of one checkpoint share a log file
    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds (only rank 0 prints, to avoid duplicated log lines)
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # sanity check for models without ema: fall back to the raw weights
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    if not distributed:
        _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
        model = MMDataParallel(model, device_ids=[0])

        # build metrics
        if args.eval:
            if args.eval[0] == 'none':
                # only sample images; no metric computation
                metrics = []
                assert args.num_samples is not None and args.num_samples > 0
            else:
                metrics = [
                    build_metric(cfg.metrics[metric]) for metric in args.eval
                ]
        else:
            # no --eval given: run every metric declared in the config
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in cfg.metrics
            ]

        basic_table_info = dict(
            train_cfg=os.path.basename(cfg._filename),
            ckpt=ckpt,
            sample_model=args.sample_model)

        if len(metrics) == 0:
            basic_table_info['num_samples'] = args.num_samples
            data_loader = None
        else:
            # -1: the evaluation helpers derive the count from the metrics
            basic_table_info['num_samples'] = -1
            # build the dataloader, preferring test > val > train splits
            if cfg.data.get('test', None):
                dataset = build_dataset(cfg.data.test)
            elif cfg.data.get('val', None):
                dataset = build_dataset(cfg.data.val)
            else:
                dataset = build_dataset(cfg.data.train)
            data_loader = build_dataloader(
                dataset,
                samples_per_gpu=args.batch_size,
                workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                             cfg.data.workers_per_gpu),
                dist=distributed,
                shuffle=True)

        if args.sample_cfg is None:
            args.sample_cfg = dict()

        # online mode will not save samples
        if args.online and len(metrics) > 0:
            single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                         basic_table_info, args.batch_size,
                                         **args.sample_cfg)
        else:
            single_gpu_evaluation(model, data_loader, metrics, logger,
                                  basic_table_info, args.batch_size,
                                  args.samples_path, **args.sample_cfg)
    else:
        # fixed grammar in the user-facing message ("We hasn't" -> "haven't")
        raise NotImplementedError("We haven't implemented multi gpu eval yet.")
# NOTE(review): this is a fragment of an argument-parsing / model-setup
# routine; the enclosing function's header (where `parser` is created) and its
# continuation (after "# get generator") lie outside this chunk, so only
# formatting and comments are added here.

# sampling hyper-parameters for a style-based generator
# (4096 latents are presumably averaged for the truncation trick — confirm)
parser.add_argument('--truncation-mean', type=int, default=4096)
parser.add_argument('--noise-channels', type=int, default=512)
parser.add_argument('--input-scale', type=int, default=4)
parser.add_argument(
    '--sample-cfg',
    nargs='+',
    action=DictAction,
    help='Other customized kwargs for sampling function')
# system args
parser.add_argument('--num-samples', type=int, default=2)
parser.add_argument('--sample-path', type=str, default=None)
parser.add_argument('--random-seed', type=int, default=2020)
args = parser.parse_args()

# seed all RNGs before any sampling happens
set_random_seed(args.random_seed)

cfg = mmcv.Config.fromfile(args.config)
# honor the cudnn benchmark flag from the config
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True

mmcv.print_log('Building models and loading checkpoints', 'mmgen')
# build model and restore weights; `args.ckpt` is presumably added by an
# earlier `parser.add_argument` outside this fragment — TODO confirm
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.eval()
load_checkpoint(model, args.ckpt, map_location='cpu')
# get generator
def main():
    """Evaluate an image-to-image translation model on a single GPU.

    Loads config and checkpoint, builds metrics (or none, for sample-only
    runs), resolves the source/target translation domains, then delegates to
    ``single_gpu_online_evaluation`` or ``single_gpu_evaluation``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # remote (http) checkpoints have no local directory for the eval log,
    # so logging falls back to stdout only
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    # append mode: repeated evaluations of one checkpoint share the log file
    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # this script only supports translation models
    assert isinstance(model, BaseTranslationModel)
    # sanity check for models without ema: fall back to the raw weights
    if not model.use_ema:
        args.sample_model = 'orig'
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # single-GPU wrapper; this script has no distributed path
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images, no metric computation
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        # no --eval given: run every metric declared in the config
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # get source domain and target domain; default target comes from the model
    target_domain = args.target_domain
    if target_domain is None:
        target_domain = model.module._default_domain
    source_domain = model.module.get_other_domains(target_domain)[0]
    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model,
        source_domain=source_domain,
        target_domain=target_domain)

    # build the dataloader only when metrics will consume data
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        # -1 presumably means "derive the count from the metrics" — verify
        # against the evaluation helpers
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=False,
            shuffle=True)

    # online mode evaluates on the fly and will not save samples
    if args.online:
        single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                     basic_table_info, args.batch_size)
    else:
        single_gpu_evaluation(model, data_loader, metrics, logger,
                              basic_table_info, args.batch_size,
                              args.samples_path)
def main():
    """Project real images into a style-based generator's latent space.

    Optimizes a latent code (and per-layer injected noises) so the generator
    reproduces each input image, guided by a perceptual (LPIPS-style) loss,
    a noise regularizer, and MSE. Saves projected images and a
    ``project_result.pt`` dict with the latent and noises per input file.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema: fall back to the raw generator
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    generator.eval()
    device = 'cpu'
    if not args.use_cpu:
        generator = generator.cuda()
        device = 'cuda'

    # optimization runs at most at 256x256 (see downsampling in the loop)
    img_size = min(generator.out_size, 256)
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # read images; channel reversal presumably converts RGB (PIL) to the
    # BGR order the generator was trained with — TODO confirm
    imgs = []
    for imgfile in args.files:
        img = Image.open(imgfile).convert('RGB')
        img = transform(img)
        img = img[[2, 1, 0], ...]
        imgs.append(img)
    imgs = torch.stack(imgs, 0).to(device)

    # get mean and standard deviation of style latents by sampling the
    # style-mapping network
    with torch.no_grad():
        noise_sample = torch.randn(
            args.n_mean_latent, generator.style_channels, device=device)
        latent_out = generator.style_mapping(noise_sample)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      args.n_mean_latent)**0.5

    # start every image's latent at the mean latent
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(
        imgs.shape[0], 1)
    if args.w_plus:
        # w+: one latent per generator layer instead of a single shared w
        latent_in = latent_in.unsqueeze(1).repeat(1, generator.num_latents, 1)
    latent_in.requires_grad = True

    # define lpips loss
    percept = PerceptualLoss(use_gpu=device.startswith('cuda'))

    # initialize per-layer injected noises, one copy per input image
    noises_single = generator.make_injected_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
    for noise in noises:
        noise.requires_grad = True

    # both the latent and the injected noises are optimized
    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
    pbar = tqdm(range(args.total_iters))

    # run optimization
    for i in pbar:
        t = i / args.total_iters
        # scheduled learning rate with ramp-up/ramp-down
        lr = get_lr(t, args.lr, args.lr_rampdown, args.lr_rampup)
        optimizer.param_groups[0]['lr'] = lr
        # latent perturbation strength decays to 0 over args.noise_ramp
        noise_strength = latent_std * args.noise * max(
            0, 1 - t / args.noise_ramp)**2
        latent_n = latent_noise(latent_in, noise_strength.item())
        img_gen = generator(
            [latent_n], input_is_latent=True, injected_noise=noises)

        # downsample generated images above 256px by block-averaging so the
        # losses compare at the same resolution as the inputs
        batch, channel, height, width = img_gen.shape
        if height > 256:
            factor = height // 256
            img_gen = img_gen.reshape(batch, channel, height // factor,
                                      factor, width // factor, factor)
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)
        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # keep injected noises normalized after each step
        noise_normalize_(noises)

        pbar.set_description(
            f' perceptual: {p_loss.item():.4f}, noise regularize:'
            f'{n_loss.item():.4f}, mse: {mse_loss.item():.4f}, lr: {lr:.4f}')

    # final render from the optimized latent (without the perturbation noise)
    results = generator(
        [latent_in.detach().clone()],
        input_is_latent=True,
        injected_noise=noises)
    # rescale value range to [0, 1]
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    mmcv.mkdir_or_exist(args.results_path)
    # save projection results; note 'img' stores the last optimization-loop
    # render (img_gen), not the final `results` tensor
    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i:i + 1])
        result_file[input_name] = {
            'img': img_gen[i],
            'latent': latent_in[i],
            'injected_noise': noise_single,
        }
        img_name = os.path.splitext(
            os.path.basename(input_name))[0] + '-project.png'
        save_image(results[i], os.path.join(args.results_path, img_name))
    torch.save(result_file,
               os.path.join(args.results_path, 'project_result.pt'))
def main():
    """Evaluate a generative model, with optional distributed (multi-GPU) mode.

    Supports online evaluation (metrics fed on the fly, required for
    distributed runs) and offline evaluation (samples saved to disk first).
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # resolve GPU ids; --gpu-ids is deprecated in favor of --gpu-id
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        rank = 0
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        rank, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
        # offline mode cannot coordinate sample files across ranks
        assert args.online or world_size == 1, (
            'We only support online mode for distrbuted evaluation.')

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # remote checkpoints have no local directory for the eval log
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    # append mode: repeated evaluations of one checkpoint share the log file
    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds (only rank 0 prints, to avoid duplicate log lines)
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # sanity check for models without ema: fall back to the raw weights
    if not model.use_ema:
        args.sample_model = 'orig'
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    model.eval()

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images, no metric computation
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        # no --eval given: run every metric declared in the config
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # check metrics for dist evaluation: only some metrics can aggregate
    # across ranks
    if distributed and metrics:
        for metric in metrics:
            assert metric.name in _distributed_metrics, (
                f'We only support {_distributed_metrics} for multi gpu '
                f'evaluation, but receive {args.eval}.')

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model)

    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        # -1: the evaluation helpers derive the count from the metrics
        basic_table_info['num_samples'] = -1
        # build the dataloader; prefer splits that define imgs_root
        if cfg.data.get('test', None) and cfg.data.test.get('imgs_root', None):
            dataset = build_dataset(cfg.data.test)
        elif cfg.data.get('val', None) and cfg.data.val.get('imgs_root', None):
            dataset = build_dataset(cfg.data.val)
        elif cfg.data.get('train', None):
            # we assume that the train part should work well
            dataset = build_dataset(cfg.data.train)
        else:
            raise RuntimeError('There is no valid dataset config to run, '
                               'please check your dataset configs.')

        # The default loader config
        loader_cfg = dict(
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=True)
        # The overall dataloader settings: copy all top-level data options
        # except per-split dataset/loader definitions
        loader_cfg.update({
            k: v
            for k, v in cfg.data.items() if k not in [
                'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
                'test_dataloader'
            ]
        })
        # specific config for test loader overrides the defaults
        test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})}
        data_loader = build_dataloader(dataset, **test_loader_cfg)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    # wrap for single-GPU or distributed execution
    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # online mode will not save samples
    if args.online and len(metrics) > 0:
        online_evaluation(model, data_loader, metrics, logger,
                          basic_table_info, args.batch_size, **args.sample_cfg)
    else:
        offline_evaluation(model, data_loader, metrics, logger,
                           basic_table_info, args.batch_size,
                           args.samples_path, **args.sample_cfg)
def main():
    """Evaluate an image translation model (CycleGAN / Pix2Pix) on one GPU.

    Translates source-domain images from the dataloader, saves the fakes
    under a samples directory (a temp dir is used and deleted afterwards when
    ``--samples-path`` is not given), then feeds real and fake images to each
    metric and logs a summary table.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # remote checkpoints have no local directory for the eval log
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    # append mode: repeated evaluations of one checkpoint share the log file
    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    assert isinstance(model, _supported_model)
    # sanity check for models without ema: fall back to the raw weights
    if not model.use_ema:
        args.sample_model = 'orig'
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images, no metric computation
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model)

    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
    else:
        basic_table_info['num_samples'] = -1

    # build the dataloader unconditionally: unlike unconditional GANs, a
    # translation model always needs source-domain inputs, even in the
    # sample-only ('none') mode.  (Bugfix: previously `data_loader` was left
    # as None when no metrics were requested, crashing `iter(data_loader)`
    # in the sampling loop below.)
    if cfg.data.get('test', None):
        dataset = build_dataset(cfg.data.test)
    else:
        dataset = build_dataset(cfg.data.train)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=args.batch_size,
        workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                     cfg.data.workers_per_gpu),
        dist=False,
        shuffle=True)

    # decide samples path; fall back to a fresh temp dir that is removed
    # after evaluation
    samples_path = args.samples_path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images: count already-saved fakes so interrupted runs resume
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)

    # select key to fetch fake images
    fake_key = 'fake_b'
    if isinstance(model.module, CycleGAN):
        fake_key = 'fake_b' if model.module.test_direction == 'a2b' else \
            'fake_a'

    # Iterate over image indices [num_exist, max_num_images) so that exactly
    # `num_needed` new images are produced.  (Bugfix: the loop previously ran
    # `range(num_exist, num_needed, ...)`, which undercounts — or produces
    # nothing — whenever some samples already exist.)
    for begin in range(num_exist, max_num_images, args.batch_size):
        end = min(begin + args.batch_size, max_num_images)
        # for translation models we feed them images from the dataloader;
        # a fresh iterator each step draws a newly shuffled first batch
        data_loader_iter = iter(data_loader)
        data_batch = next(data_loader_iter)
        output_dict = model(test_mode=True, **data_batch)
        fakes = output_dict[fake_key]
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            # rescale from [-1, 1] to [0, 1] and reverse channel order
            images = ((images + 1) / 2)
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0:
        # move to a fresh line after the progress bar
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, args.batch_size)

    # select key to fetch real images (the target domain of the translation)
    if isinstance(model.module, CycleGAN):
        real_key = 'img_b' if model.module.test_direction == 'a2b' else 'img_a'
        if model.module.direction == 'b2a':
            real_key = 'img_a' if real_key == 'img_b' else 'img_b'
    if isinstance(model.module, Pix2Pix):
        real_key = 'img_b' if model.module.direction == 'a2b' else 'img_a'

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images until the metric has enough
        for data in data_loader:
            reals = data[real_key]
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images until the metric has enough
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()

    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    if delete_samples_path:
        shutil.rmtree(samples_path)
def main():
    """Training entry point: resolve config/work_dir, set up logging and
    seeds, build model and datasets, then hand off to ``train_model``."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config into the work_dir for reproducibility
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    # a two-stage workflow additionally trains on a copy of the val split
    # (using the train pipeline)
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmgen version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmgen_version=__version__ + get_git_hash()[:7])
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
def main():
    """Interpolate between random latent endpoints of a trained generator.

    In 'sequence' mode each consecutive pair of endpoints is interpolated and
    frames are saved individually; otherwise endpoints are paired up and the
    interpolations are saved as one image grid ('group.png').
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema: fall back to the raw generator
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')
    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()

    # group mode pairs endpoints, so their count must be even
    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)

    # get noises corresponding to each endpoint: raw z noise, or the mapped
    # w latent when interpolating in w space
    noise_batch = batch_inference(
        generator,
        None,
        num_batches=args.endpoint,
        dict_key='noise_batch' if args.space == 'z' else 'latent',
        return_noise=True,
        **kwargs)

    if args.space == 'w':
        # w-space sampling feeds latents directly and needs the mean latent
        # for truncation
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True

    if args.show_mode == 'sequence':
        # interpolate between each consecutive pair of endpoints
        results = sample_from_path(generator, noise_batch[:-1, ],
                                   noise_batch[1:, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    else:
        # interpolate between endpoint pairs (0,1), (2,3), ...
        results = sample_from_path(generator, noise_batch[::2, ],
                                   noise_batch[1::2, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)

    # reorder results from (step, batch, C, H, W) to (batch, step, C, H, W),
    # then flatten to a plain image batch
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1] and reverse channel order
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    # save image
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        for i in range(results.shape[0]):
            image = results[i:i + 1]
            save_image(
                image,
                os.path.join(args.samples_path,
                             '{:0>5d}'.format(i) + '.png'))
    else:
        save_image(
            results,
            os.path.join(args.samples_path, 'group.png'),
            nrow=args.interval)
def main(): args = parse_args() # set cudnn_benchmark cfg = Config.fromfile(args.config) if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True # set random seeds if args.seed is not None: print('set random seed to', args.seed) set_random_seed(args.seed, deterministic=args.deterministic) os.makedirs(args.results_dir, exist_ok=True) text_inputs = torch.cat([clip.tokenize(args.description)]).cuda() model = init_model(args.config, args.checkpoint, device='cpu') g_ema = model.generator_ema g_ema.eval() if not args.use_cpu: g_ema = g_ema.cuda() mean_latent = g_ema.get_mean_latent() # if given proj_latent if args.proj_latent is not None: mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen') proj_file = torch.load(args.proj_latent) proj_n = len(proj_file) assert proj_n == 1 noise_batch = [] for img_path in proj_file: noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0)) latent_code_init = torch.cat(noise_batch, dim=0).cuda() elif args.mode == 'edit': latent_code_init_not_trunc = torch.randn(1, 512).cuda() with torch.no_grad(): results = g_ema([latent_code_init_not_trunc], return_latents=True, truncation=args.truncation, truncation_latent=mean_latent) latent_code_init = results['latent'] else: latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1) with torch.no_grad(): img_orig = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False) latent = latent_code_init.detach().clone() latent.requires_grad = True clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size)) id_loss = FaceIdLoss( facenet=dict(type='ArcFace', ir_se50_weights=None, device='cuda')) optimizer = optim.Adam([latent], lr=args.lr) pbar = tqdm(range(args.step)) mmcv.print_log(f'Description: {args.description}') for i in pbar: t = i / args.step lr = get_lr(t, args.lr) optimizer.param_groups[0]['lr'] = lr img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False) img_gen = img_gen[:, [2, 1, 0], ...] 
# clip loss c_loss = clip_loss(image=img_gen, text=text_inputs) if args.id_lambda > 0: i_loss = id_loss(pred=img_gen, gt=img_orig)[0] else: i_loss = 0 if args.mode == 'edit': l2_loss = ((latent_code_init - latent)**2).sum() loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss else: loss = c_loss optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_description((f'loss: {loss.item():.4f};')) if args.save_interval > 0 and (i % args.save_interval == 0): with torch.no_grad(): img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False) img_gen = img_gen[:, [2, 1, 0], ...] torchvision.utils.save_image(img_gen, os.path.join( args.results_dir, f'{str(i).zfill(5)}.png'), normalize=True, range=(-1, 1)) if args.mode == 'edit': img_orig = img_orig[:, [2, 1, 0], ...] final_result = torch.cat([img_orig, img_gen]) else: final_result = img_gen torchvision.utils.save_image(final_result.detach().cpu(), os.path.join(args.results_dir, 'final_result.png'), normalize=True, scale_each=True, range=(-1, 1))
def main():
    """Interpolate between latent endpoints, with optional projected latents
    and optional video export.

    Endpoints come either from a projected-latent file (``--proj-latent``,
    w space only) or from random sampling. 'sequence' mode interpolates
    consecutive endpoints; otherwise endpoints are paired up. Results are
    written as individual frames, an image grid, or an mp4 video.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # sanity check for models without ema: fall back to the raw generator
    if not model.use_ema:
        args.sample_model = 'orig'
    if args.sample_model == 'ema':
        generator = model.generator_ema
    else:
        generator = model.generator
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
    mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')
    generator.eval()

    if not args.use_cpu:
        generator = generator.cuda()

    # if given proj_latent, reset args.endpoint to the number of projected
    # latents in the file
    if args.proj_latent is not None:
        mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
        proj_file = torch.load(args.proj_latent)
        proj_n = len(proj_file)
        setattr(args, 'endpoint', proj_n)
        assert args.space == 'w', 'Projected latent are w or w-plus latent.'
        noise_batch = []
        for img_path in proj_file:
            noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
        noise_batch = torch.cat(noise_batch, dim=0).cuda()
        if args.use_cpu:
            noise_batch = noise_batch.to('cpu')

    # group mode pairs endpoints, so their count must be even
    if args.show_mode == 'sequence':
        assert args.endpoint >= 2
    else:
        assert args.endpoint >= 2 and args.endpoint % 2 == 0,\
            '''We need paired images in group mode, so keep endpoint an even number'''

    kwargs = dict(max_batch_size=args.batch_size)
    if args.sample_cfg is None:
        args.sample_cfg = dict()
    kwargs.update(args.sample_cfg)

    # remind users to fix injected noise for smooth interpolation
    # NOTE(review): the default here is the *string* 'True' — truthy, so
    # behavior matches intent, but a boolean True would be cleaner.
    if kwargs.get('randomize_noise', 'True'):
        mmcv.print_log(
            '''Hint: For Style-Based GAN, you can add `--sample-cfg randomize_noise=False` to fix injected noises''',
            'mmgen')

    # get noises corresponding to each endpoint (skipped when projected
    # latents already provided the endpoints)
    if not args.proj_latent:
        noise_batch = batch_inference(
            generator,
            None,
            num_batches=args.endpoint,
            dict_key='noise_batch' if args.space == 'z' else 'latent',
            return_noise=True,
            **kwargs)

    if args.space == 'w':
        # w-space sampling feeds latents directly and needs the mean latent
        # for truncation
        kwargs['truncation_latent'] = generator.get_mean_latent()
        kwargs['input_is_latent'] = True

    if args.show_mode == 'sequence':
        # interpolate between each consecutive pair of endpoints
        results = sample_from_path(generator, noise_batch[:-1, ],
                                   noise_batch[1:, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)
    else:
        # interpolate between endpoint pairs (0,1), (2,3), ...
        results = sample_from_path(generator, noise_batch[::2, ],
                                   noise_batch[1::2, ], args.interval,
                                   args.interp_mode, args.space, **kwargs)

    # reorder results from (step, batch, C, H, W) to (batch, step, C, H, W),
    # then flatten to a plain image batch
    results = torch.stack(results).permute(1, 0, 2, 3, 4)
    _, _, ch, h, w = results.shape
    results = results.reshape(-1, ch, h, w)
    # rescale value range to [0, 1] and reverse channel order
    results = ((results + 1) / 2)
    results = results[:, [2, 1, 0], ...]
    results = results.clamp_(0, 1)

    # save image
    mmcv.mkdir_or_exist(args.samples_path)
    if args.show_mode == 'sequence':
        if args.export_video:
            # render video.
            video_out = imageio.get_writer(
                os.path.join(args.samples_path, 'lerp.mp4'),
                mode='I',
                fps=60,
                codec='libx264',
                bitrate='12M')
            video_out = layout_grid(video_out, results)
            video_out.close()
        else:
            for i in range(results.shape[0]):
                image = results[i:i + 1]
                save_image(
                    image,
                    os.path.join(args.samples_path,
                                 '{:0>5d}'.format(i) + '.png'))
    else:
        if args.export_video:
            # render video: one grid cell per endpoint pair
            video_out = imageio.get_writer(
                os.path.join(args.samples_path, 'lerp.mp4'),
                mode='I',
                fps=60,
                codec='libx264',
                bitrate='12M')
            n_pair = args.endpoint // 2
            grid_w, grid_h = crack_integer(n_pair)
            video_out = layout_grid(
                video_out, results, grid_h=grid_h, grid_w=grid_w)
            video_out.close()
        else:
            save_image(
                results,
                os.path.join(args.samples_path, 'group.png'),
                nrow=args.interval)