def main(args):
    """Reconstruct a single image with a trained autoencoder and save the result.

    The reconstruction is written to ``<train run root>/<args.output_dir>``,
    with the StyleGAN variant and latent mode encoded in the file name.
    """
    root_dir = Path(args.autoencoder_checkpoint).parent.parent
    # The output directory lives next to the checkpoint's train run, not in the CWD.
    output_dir = root_dir / args.output_dir
    output_dir.mkdir(exist_ok=True, parents=True)

    config = load_config(args.autoencoder_checkpoint, None)
    config['batch_size'] = 1
    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder, args.autoencoder_checkpoint, key='autoencoder')

    input_image = Path(args.image)
    data_loader = build_data_loader(input_image, config, config['absolute'],
                                    shuffle_off=True, dataset_class=DemoDataset)
    image = next(iter(data_loader))
    image = {k: v.to(args.device) for k, v in image.items()}

    reconstructed = Image.fromarray(
        make_image(autoencoder(image['input_image'])[0].squeeze(0)))
    # BUG FIX: save into the directory that was actually created above
    # (``output_dir``). The original built the path from ``args.output_dir``
    # alone, which resolves relative to the CWD and was never created,
    # so the save could fail or land in the wrong place.
    output_name = output_dir / f"reconstructed_{input_image.stem}_stylegan_{config['stylegan_variant']}_{'w_only' if config['w_only'] else 'w_plus'}.png"
    reconstructed.save(output_name)
def embed_images(
        args: argparse.Namespace, config: dict, dataset: Path
) -> Tuple[numpy.ndarray, List[numpy.ndarray], numpy.ndarray]:
    """Encode every image of ``dataset`` with the trained autoencoder.

    Returns a tuple of (stacked latent codes, one stacked array per noise map,
    array of the corresponding image file names).
    """
    data_loader = build_data_loader(dataset, config, config['absolute'], shuffle_off=True)

    # Optionally restrict the run to a random subset of the dataset.
    if args.num_samples is not None:
        random.seed(args.seed if args.seed != 'none' else None)
        data_loader.dataset.image_data = random.sample(
            data_loader.dataset.image_data, args.num_samples)

    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder, args.autoencoder_checkpoint, key='autoencoder')

    collected_latents = []
    collected_noises = None
    names = []
    for batch_id, batch in enumerate(tqdm(data_loader)):
        if isinstance(batch, dict):
            batch = batch['input_image']
        batch = batch.to(args.device)

        with torch.no_grad():
            encoded: Latents = autoencoder.encode(batch)
        encoded = encoded.numpy()

        collected_latents.append(encoded.latent)
        if collected_noises is None:
            # First batch: open one accumulator list per noise map.
            collected_noises = [[noise_map] for noise_map in encoded.noise]
        else:
            for accumulator, noise_map in zip(collected_noises, encoded.noise):
                accumulator.append(noise_map)

        # Map batch positions back to dataset file names (loader is not shuffled).
        offset = batch_id * config['batch_size']
        for position in range(len(batch)):
            names.append(data_loader.dataset.image_data[offset + position])

    stacked_latents = numpy.concatenate(collected_latents, axis=0)
    stacked_noises = [numpy.concatenate(accumulator, axis=0) for accumulator in collected_noises]
    return stacked_latents, stacked_noises, numpy.array(names)
def evaluate_denoising(args):
    """Evaluate a denoising autoencoder checkpoint on a test dataset.

    Computes mean PSNR/SSIM over the dataset and writes them to a JSON file
    under the train run's evaluation directory; optionally saves qualitative
    (original, noisy, denoised) image triples.
    """
    config = load_config(args.model_checkpoint, None)
    args.test_dataset = Path(args.test_dataset)
    assert config['denoising'] is True or config['black_and_white_denoising'] is True, "you are supplying a train run that has not been trained for denoising! Aborting"

    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder, args.model_checkpoint, key='autoencoder', strict=True)

    # Evaluate one image at a time.
    config['batch_size'] = 1
    data_loader = build_data_loader(args.test_dataset, config, config['absolute'],
                                    shuffle_off=True, dataset_class=DenoisingEvaluationDataset)

    evaluation_root = Path(args.model_checkpoint).parent.parent / 'evaluation' / f"denoise_{args.dataset_name}"
    evaluation_root.mkdir(parents=True, exist_ok=True)

    quality_evaluator = PSNRSSIMEvaluator()
    metrics = defaultdict(list)
    for i, batch in tenumerate(data_loader, leave=False):
        batch = {k: v.to(args.device) for k, v in batch.items()}
        with torch.no_grad():
            denoised = autoencoder(batch['noisy'])

        noisy = clamp_and_unnormalize(batch['noisy'])
        original = clamp_and_unnormalize(batch['original'])
        denoised = clamp_and_unnormalize(denoised)

        if args.save:
            save_dir = evaluation_root / "qualitative" / args.test_dataset.stem
            save_dir.mkdir(exist_ok=True, parents=True)
            save_images([original[0], noisy[0], denoised[0]], save_dir, i)

        psnr, ssim = quality_evaluator.psnr_and_ssim(denoised, original)
        metrics['psnr'].append(float(psnr.cpu().numpy()))
        metrics['ssim'].append(float(ssim.cpu().numpy()))

    # Average per-image scores before writing them out.
    averaged = {name: statistics.mean(values) for name, values in metrics.items()}
    evaluation_file = evaluation_root / f'denoising_{args.test_dataset.stem}.json'
    with evaluation_file.open('w') as f:
        json.dump(averaged, f, indent='\t')
def evaluate_checkpoint(checkpoint: str, dataset: dict, args: argparse.Namespace):
    """Run FID and/or reconstruction evaluation for one checkpoint on one dataset.

    Metrics that were already evaluated, or disabled via ``args``, are skipped.
    Results are written into the train run's ``evaluation`` directory.
    """
    checkpoint = Path(checkpoint)
    evaluation_root = checkpoint.parent.parent / 'evaluation'
    evaluation_root.mkdir(exist_ok=True)

    dataset_name = dataset.pop('name')  # NOTE: mutates the caller's dict
    to_evaluate = has_not_been_evaluated(checkpoint.name, dataset_name, evaluation_root)
    if not args.fid:
        to_evaluate['fid'] = False
    if not args.reconstruction:
        to_evaluate['reconstruction'] = False
    if not any(to_evaluate.values()):
        # there is nothing to evaluate
        return

    config = load_config(checkpoint, None)
    dataset = {split: Path(path) for split, path in dataset.items()}

    autoencoder = get_autoencoder(config).to('cuda')
    autoencoder = load_weights(autoencoder, checkpoint, key='autoencoder', strict=True)

    config['batch_size'] = 1
    dataset_class = get_dataset_class(argparse.Namespace(**config))
    data_loaders = {
        split: build_data_loader(path, config, config['absolute'],
                                 shuffle_off=True, dataset_class=dataset_class)
        for split, path in dataset.items()
    }

    if to_evaluate['fid']:
        fid_result = evaluate_fid(autoencoder, data_loaders, dataset)
        save_eval_result(fid_result, "fid", evaluation_root, dataset_name, checkpoint.name)
    if to_evaluate['reconstruction']:
        reconstruction_result = evaluate_reconstruction(autoencoder, data_loaders)
        save_eval_result(reconstruction_result, "reconstruction", evaluation_root, dataset_name, checkpoint.name)

    # Free GPU memory before the next checkpoint is evaluated.
    del autoencoder
    torch.cuda.empty_cache()
def main(args):
    """Create latent-space interpolation images between random pairs of embedded images."""
    embedding_dir = Path(args.embedding_file).parent
    embedded_data = numpy.load(args.embedding_file, mmap_mode='r')

    # The checkpoint name is encoded in the embedding file name
    # (third-to-last underscore-separated token of the stem).
    checkpoint_for_embedding = embedding_dir.parent / 'checkpoints' / f"{Path(args.embedding_file).stem.split('_')[-3]}.pt"
    config = load_config(checkpoint_for_embedding, None)
    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder, checkpoint_for_embedding, key='autoencoder', strict=True)

    num_images = len(embedded_data['image_names'])
    interpolation_dir = embedding_dir / 'interpolations'
    interpolation_dir.mkdir(parents=True, exist_ok=True)

    is_w_plus = not config['w_only']
    for _ in range(args.num_images):
        # Pick two distinct embedded images to interpolate between.
        start_idx, end_idx = random.sample(list(range(num_images)), k=2)
        start_latent, start_noises = load_embeddings(embedded_data, start_idx)
        end_latent, end_noises = load_embeddings(embedded_data, end_idx)

        for num_steps in args.steps:
            interpolation = make_interpolation_image(num_steps, args.device, autoencoder, is_w_plus,
                                                     start_latent, end_latent, start_noises, end_noises)
            destination = interpolation_dir / f"{start_idx}_{end_idx}_all_{num_steps}_steps.png"
            interpolation.save(str(destination))
def main(args: argparse.Namespace):
    """Compute the FID score of an autoencoder checkpoint on a dataset and save it."""
    dest_dir = Path(args.model_checkpoint).parent.parent / 'evaluation'
    dest_dir.mkdir(parents=True, exist_ok=True)

    config = load_config(args.model_checkpoint, None)
    dataset = Path(args.dataset)
    config['batch_size'] = args.batch_size
    data_loader = build_data_loader(dataset, config, config['absolute'], shuffle_off=True)

    autoencoder = get_autoencoder(config).to(args.device)
    autoencoder = load_weights(autoencoder, args.model_checkpoint, key='autoencoder')

    fid_calculator = FID(args.num_samples, device=args.device)
    fid_score = fid_calculator(autoencoder, data_loader, args.dataset)

    save_fid_score(fid_score, dest_dir, args.dataset_name)
    print(f"FID Score for {args.dataset_name} is {fid_score}.")
def main(args):
    """Render the noise maps an autoencoder produces for images (or for freshly
    generated latents) as side-by-side image grids and save them to the train
    run's ``noise_maps`` directory.
    """
    checkpoint_path = Path(args.model_checkpoint)
    config = load_config(checkpoint_path, None)
    autoencoder = get_autoencoder(config).to(args.device)
    load_weights(autoencoder, checkpoint_path, key='autoencoder', strict=True)
    config['batch_size'] = 1
    if args.generate:
        # Sample latents/noise maps directly instead of encoding real images.
        data_loader = build_latent_and_noise_generator(autoencoder, config)
    else:
        data_loader = build_data_loader(args.images, config, args.absolute, shuffle_off=True)
    noise_dest_dir = checkpoint_path.parent.parent / "noise_maps"
    noise_dest_dir.mkdir(parents=True, exist_ok=True)
    num_images = 0
    for idx, batch in enumerate(tqdm(data_loader, total=args.num_images)):
        batch = batch.to(args.device)
        if args.generate:
            # Generated batches already are latents; invent an output file name.
            latents = batch
            image_names = [Path(f"generate_{idx}.png")]
        else:
            with torch.no_grad():
                latents: Latents = autoencoder.encode(batch)
            # Map batch positions back to dataset file names (loader is not shuffled).
            image_names = [
                Path(data_loader.dataset.image_data[idx * config['batch_size'] + batch_idx])
                for batch_idx in range(len(batch))
            ]
        if args.shift_noise:
            # Additionally render results with shifted noise maps (args.rounds rows).
            noise_shifted_tensors = render_with_shifted_noise(autoencoder, latents, args.rounds)
        # Convert each noise tensor to a PIL image, upscaled to the model's image
        # size with nearest-neighbour so individual noise pixels stay visible.
        images = []
        for noise_tensors in latents.noise:
            noise_images = make_image(noise_tensors, normalize_func=noise_normalize)
            images.append([
                Image.fromarray(im).resize(
                    (config['image_size'], config['image_size']), Image.NEAREST)
                for im in noise_images
            ])
        for batch_idx, (image, orig_file_name) in enumerate(zip(batch, image_names)):
            # First row: [input image | noise map 1 | noise map 2 | ...];
            # with --shift-noise, extra rows below hold the shifted renderings.
            full_image = Image.new(
                'RGB',
                ((len(images) + 1) * config['image_size'],
                 config['image_size'] if not args.shift_noise else config['image_size'] * (args.rounds + 1)))
            if not args.generate:
                # Generated latents have no input image to show in the first column.
                full_image.paste(Image.fromarray(make_image(image)), (0, 0))
            for i, noise_images in enumerate(images):
                full_image.paste(noise_images[batch_idx], ((i + 1) * config['image_size'], 0))
            if args.shift_noise:
                for i, shifted_images in enumerate(noise_shifted_tensors):
                    for j, shifted_image in enumerate(shifted_images):
                        full_image.paste(shifted_image,
                                         (j * config['image_size'], (i + 1) * config['image_size']))
            full_image.save(noise_dest_dir / f"{orig_file_name.stem}_noise.png")
        num_images += len(image_names)
        if num_images >= args.num_images:
            break
def main(args, rank, world_size):
    """Distributed training entry point for the autoencoder (optionally with a
    discriminator).

    ``rank``/``world_size`` describe this process's place in the (possibly
    single-process) setup; rank 0 additionally writes snapshots and acts as
    the master for logging.
    """
    config = load_yaml_config(args.config)
    config = merge_config_and_args(config, args)
    dataset_class = get_dataset_class(args)
    train_data_loader = build_data_loader(args.images, config, args.absolute,
                                          dataset_class=dataset_class)
    autoencoder = get_autoencoder(config, init_ckpt=args.stylegan_checkpoint)
    discriminator = None
    if args.use_discriminator:
        discriminator = get_discriminator(config)
    if args.disable_update_for == 'latent':
        # Noise-only training: the encoder weights must come from a pretrained checkpoint.
        assert args.autoencoder is not None, "if you want to only train noise, we need an autoencoder checkoint!"
        print(
            f"Loading encoder weights from {args.autoencoder} for noise-only training."
        )
        load_weights(autoencoder.encoder, args.autoencoder, key='encoder', strict=False)
    elif args.autoencoder is not None:
        print(f"Loading all weights from {args.autoencoder}.")
        load_weights(autoencoder, args.autoencoder, key='autoencoder')
    optimizer_opts = {
        'betas': (config['beta1'], config['beta2']),
        'weight_decay': config['weight_decay'],
        'lr': float(config['lr']),
    }
    if args.disable_update_for != 'none':
        # Parts of the network are frozen: a single parameter group is used,
        # so a separate noise learning rate would not be honoured.
        if float(config['lr_to_noise']) != float(config['lr']):
            print(
                "Warning: updates for some parts of the networks are disabled. "
                f"Therefore 'lr_to_noise'={config['lr_to_noise']} is ignored.")
        optimizer = GradientClipAdam(autoencoder.trainable_parameters(), **optimizer_opts)
    else:
        # Split parameters so the noise-producing layers get their own learning rate.
        main_param_group, noise_param_group = autoencoder.trainable_parameters(
            as_groups=(["to_noise", "intermediate_to_noise"], ))
        noise_param_group['lr'] = float(config['lr_to_noise'])
        optimizer = GradientClipAdam([main_param_group, noise_param_group], **optimizer_opts)
    if world_size > 1:
        # Multi-GPU: wrap networks in DistributedDataParallel.
        distributed = functools.partial(DDP, device_ids=[rank],
                                        find_unused_parameters=True,
                                        broadcast_buffers=False,
                                        output_device=rank)
        autoencoder = distributed(autoencoder.to('cuda'))
        if discriminator is not None:
            discriminator = distributed(discriminator.to('cuda'))
    else:
        autoencoder = autoencoder.to('cuda')
        if discriminator is not None:
            discriminator = discriminator.to('cuda')
    if discriminator is not None:
        discriminator_optimizer = GradientClipAdam(discriminator.parameters(), **optimizer_opts)
        updater = AutoencoderDiscriminatorUpdater(
            iterators={'images': train_data_loader},
            networks={
                'autoencoder': autoencoder,
                'discriminator': discriminator
            },
            optimizers={
                'main': optimizer,
                'discriminator': discriminator_optimizer
            },
            device='cuda',
            copy_to_device=world_size == 1,
            disable_update_for=args.disable_update_for,
        )
    else:
        updater = AutoencoderUpdater(
            iterators={'images': train_data_loader},
            networks={'autoencoder': autoencoder},
            optimizers={'main': optimizer},
            device='cuda',
            copy_to_device=world_size == 1,
            disable_update_for=args.disable_update_for,
        )
    trainer = DistributedTrainer(updater, stop_trigger=get_trigger(
        (config['max_iter'], 'iteration')))
    logger = WandBLogger(
        args.log_dir,
        args,
        config,
        os.path.dirname(os.path.realpath(__file__)),
        trigger=get_trigger((config['log_iter'], 'iteration')),
        master=rank == 0,  # rank 0 is the master for logging
        project_name="One Model to Generate them All",
        run_name=args.log_name,
    )
    if args.val_images is not None:
        val_data_loader = build_data_loader(args.val_images, config, args.absolute,
                                            shuffle_off=True,
                                            dataset_class=dataset_class)
        evaluator = Evaluator(val_data_loader, logger,
                              AutoEncoderEvalFunc(autoencoder, rank), rank,
                              trigger=get_trigger((1, 'epoch')))
        trainer.extend(evaluator)
    # FID runs on validation images when available, otherwise on training images.
    fid_extension = FIDScore(
        autoencoder if not isinstance(autoencoder, DDP) else autoencoder.module,
        val_data_loader if args.val_images is not None else train_data_loader,
        dataset_path=args.val_images if args.val_images is not None else args.images,
        trigger=(1, 'epoch'))
    trainer.extend(fid_extension)
    if rank == 0:
        # Only the master process writes checkpoints; unwrap DDP first.
        snapshot_autoencoder = autoencoder if not isinstance(
            autoencoder, DDP) else autoencoder.module
        snapshotter = Snapshotter(
            {
                'autoencoder': snapshot_autoencoder,
                'encoder': snapshot_autoencoder.encoder,
                'decoder': snapshot_autoencoder.decoder,
                'optimizer': optimizer,
            },
            args.log_dir,
            trigger=get_trigger((config['snapshot_save_iter'], 'iteration')))
        trainer.extend(snapshotter)
    # Gather a fixed set of images that is plotted regularly during training.
    plot_images = []
    if args.val_images is not None:

        def fill_plot_images(data_loader):
            # Collect just enough validation images for the display grid.
            # NOTE(review): '>' means display_size + 1 images are returned — confirm intended.
            image_list = []
            num_images = 0
            for batch in data_loader:
                for image in batch['input_image']:
                    image_list.append(image)
                    num_images += 1
                    if num_images > config['display_size']:
                        return image_list
            raise RuntimeError(
                f"Could not gather enough plot images for display size {config['display_size']}."
            )

        plot_images = fill_plot_images(val_data_loader)
    else:
        for i in range(config['display_size']):
            # set_epoch changes the sampling order so each draw can yield a different image.
            if hasattr(train_data_loader.sampler, 'set_epoch'):
                train_data_loader.sampler.set_epoch(i)
            plot_images.append(
                next(iter(train_data_loader))['input_image'][0])
    image_plotter = ImagePlotter(plot_images, [autoencoder], args.log_dir,
                                 trigger=get_trigger(
                                     (config['image_save_iter'], 'iteration')),
                                 plot_to_logger=True)
    trainer.extend(image_plotter)
    schedulers = {
        "encoder": CosineAnnealingLR(optimizer, config["max_iter"], eta_min=1e-8)
    }
    lr_scheduler = LRScheduler(schedulers, trigger=get_trigger((1, 'iteration')))
    trainer.extend(lr_scheduler)
    trainer.extend(logger)
    # Make sure all processes reached this point before training starts.
    synchronize()
    trainer.train()