def test(self):
    """Evaluate saved checkpoints on the test split and log each one's mean loss.

    Checkpoints named ``checkpoint_<step>.pth`` are read from
    ``self.args.log_path`` for every ``step`` from
    ``self.config.test.begin_ckpt`` to ``self.config.test.end_ckpt``
    (inclusive) in strides of 5000. When ``self.config.model.ema`` is set,
    the last element of the checkpoint is loaded through an ``EMAHelper``
    and applied to the network; otherwise the first element is loaded
    directly as the network state dict.

    A checkpoint for which the test loader yields no batch is reported with
    a warning and skipped instead of failing on a division by zero.
    """
    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas = get_sigmas(self.config)

    # Only the test split is needed here.
    _, test_dataset = get_dataset(self.args, self.config)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=self.config.test.batch_size,
        shuffle=True,
        num_workers=self.config.data.num_workers,
        drop_last=True,
    )

    verbose = False
    for ckpt in tqdm.tqdm(
        range(self.config.test.begin_ckpt, self.config.test.end_ckpt + 1, 5000),
        desc="processing ckpt:",
    ):
        states = torch.load(
            os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
            map_location=self.config.device,
        )

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        step = 0
        total_loss = 0.0
        for x, _ in test_dataloader:
            step += 1

            x = x.to(self.config.device)
            x = data_transform(self.config, x)

            test_loss = anneal_sliced_score_estimation_vr(
                score, x, sigmas, None, self.config.training.anneal_power
            )
            if verbose:
                logging.info(
                    "step: {}, test_loss: {}".format(step, test_loss.item())
                )

            total_loss += test_loss.item()

        if step == 0:
            # drop_last=True yields nothing when the test set is smaller than
            # one batch; there is no mean to report in that case.
            logging.warning(
                "ckpt: {}, test dataloader produced no batches".format(ckpt)
            )
            continue

        mean_loss = total_loss / step
        logging.info("ckpt: {}, average test loss: {}".format(ckpt, mean_loss))
def sample(self):
    """Load model weights and run the sampling routine selected in ``self.args``.

    Without ``self.args.use_pretrained`` the weights come from
    ``ckpt.pdl`` (or ``ckpt_<ckpt_id>.pdl`` when
    ``self.config.sampling.ckpt_id`` is set) under ``self.args.log_path``.
    Entries whose key contains ``"$model_"`` are loaded into the model and,
    when ``self.config.model.ema`` is set, entries whose key contains
    ``"$ema_"`` are loaded into an ``EMAHelper`` that is then applied to the
    model. With ``use_pretrained`` a checkpoint located by
    ``get_ckpt_path`` is loaded instead.

    Raises:
        ValueError: ``use_pretrained`` is set and the dataset is neither
            ``"CIFAR10"`` nor ``"LSUN"``.
        NotImplementedError: none of ``fid``, ``interpolation`` or
            ``sequence`` is set in ``self.args``.
    """
    model = Model(self.config)

    if not self.args.use_pretrained:
        ckpt_id = getattr(self.config.sampling, "ckpt_id", None)
        if ckpt_id is None:
            states = paddle.load(os.path.join(self.args.log_path, "ckpt.pdl"))
        else:
            states = paddle.load(
                os.path.join(self.args.log_path, f"ckpt_{ckpt_id}.pdl")
            )
        model = paddle.DataParallel(model)
        # Model and EMA tensors share one flat dict; the text after the
        # "$model_" / "$ema_" marker is the key each consumer expects.
        model.set_state_dict({
            k.split("$model_")[-1]: v
            for k, v in states.items()
            if "$model_" in k
        })

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.set_state_dict({
                k.split("$ema_")[-1]: v
                for k, v in states.items()
                if "$ema_" in k
            })
            ema_helper.ema(model)
    else:
        # This used the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
        if self.config.data.dataset == "CIFAR10":
            name = "cifar10"
        elif self.config.data.dataset == "LSUN":
            name = f"lsun_{self.config.data.category}"
        else:
            raise ValueError(
                f"no pretrained checkpoint for dataset {self.config.data.dataset!r}"
            )
        ckpt = get_ckpt_path(f"ema_{name}")
        print("Loading checkpoint {}".format(ckpt))
        model.set_state_dict(paddle.load(ckpt))
        model = paddle.DataParallel(model)

    model.eval()

    if self.args.fid:
        self.sample_fid(model)
    elif self.args.interpolation:
        self.sample_interpolation(model)
    elif self.args.sequence:
        self.sample_sequence(model)
    else:
        raise NotImplementedError("Sample procedeure not defined")
def _load_states(self, score):
    """Restore weights into ``score`` from a checkpoint and return it.

    The checkpoint is ``checkpoint.pth`` under ``self.args.log_path``, or
    ``checkpoint_<ckpt_id>.pth`` when ``self.config.sampling.ckpt_id`` is
    not ``None``. With ``self.config.model.ema`` set, the last element of
    the checkpoint is loaded into an ``EMAHelper`` which is then applied to
    ``score``; otherwise the first element is loaded as the state dict.
    """
    ckpt_id = self.config.sampling.ckpt_id
    if ckpt_id is not None:
        filename = f'checkpoint_{ckpt_id}.pth'
    else:
        filename = 'checkpoint.pth'

    states = torch.load(
        os.path.join(self.args.log_path, filename),
        map_location=self.args.device,
    )

    if not self.config.model.ema:
        score.load_state_dict(states[0])
    else:
        ema = EMAHelper(mu=self.config.model.ema_rate)
        ema.register(score)
        ema.load_state_dict(states[-1])
        ema.ema(score)

    # Drop the local reference to the loaded checkpoint before returning.
    del states
    return score
def sample(self):
    """Load the networks ``D`` and ``S`` and run the selected sampling routine.

    ``D`` (a ``DensityRatioEstNet``) is restored from the ``'D'`` entry of
    ``ckpt_DRE_<sigma_sq>_<tau>.pth``. ``S`` (a ``Model``) is restored from
    ``ckpt.pth``, or ``ckpt_<ckpt_id>.pth`` when
    ``self.config.sampling.ckpt_id`` is set. Both files live under
    ``self.args.log_path``. When ``self.config.model.ema`` is set, the last
    element of the ``S`` checkpoint is loaded into an ``EMAHelper`` and
    applied to ``S``.

    Raises:
        NotImplementedError: none of ``fid``, ``interpolation``,
            ``inpainting`` or ``sbp`` is set in ``self.args``.
    """
    D = DensityRatioEstNet(
        self.config.model.ngf_d,
        self.config.data.image_size,
        self.config.data.channels,
    ).to(self.device)
    dre_path = os.path.join(
        self.args.log_path,
        f"ckpt_DRE_{self.args.sigma_sq}_{self.args.tau}.pth",
    )
    # map_location keeps the load working when the checkpoint was written on
    # a device that is not available in this process.
    D.load_state_dict(torch.load(dre_path, map_location=self.device)['D'])
    # NOTE(review): D is left in whatever mode its constructor set; only S is
    # switched to eval() below - confirm that is intended.

    S = Model(self.config)

    ckpt_id = getattr(self.config.sampling, "ckpt_id", None)
    ckpt_name = "ckpt.pth" if ckpt_id is None else f"ckpt_{ckpt_id}.pth"
    states = torch.load(
        os.path.join(self.args.log_path, ckpt_name),
        map_location=self.config.device,
    )
    S = S.to(self.device)
    S = torch.nn.DataParallel(S)
    S.load_state_dict(states[0], strict=True)

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(S)
        ema_helper.load_state_dict(states[-1])
        ema_helper.ema(S)

    S.eval()

    if self.args.fid:
        self.sample_fid(S, D)
    elif self.args.interpolation:
        self.sample_interpolation(S)
    elif self.args.inpainting:
        self.sample_inpainting(S)
    elif self.args.sbp:
        self.sample_sbp(S, D)
    else:
        raise NotImplementedError("Sample procedeure not defined")
def fast_fid(self):
    """Generate samples for a range of checkpoints and record the FID of each.

    With ``self.config.fast_fid.ensemble`` set, the work is delegated to
    ``self.fast_ensemble_fid()``. Otherwise, for every checkpoint step from
    ``fast_fid.begin_ckpt`` to ``fast_fid.end_ckpt`` (inclusive, stride
    5000), ``num_samples // batch_size`` batches are generated with
    ``anneal_Langevin_dynamics`` and written as PNG files to
    ``<image_folder>/ckpt_<step>/``. The FID of that folder is computed with
    ``get_fid`` and all results are pickled to
    ``<image_folder>/fids.pickle`` as a ``{step: fid}`` dict.

    Images are numbered across batches so that later batches do not
    overwrite earlier ones.

    Raises:
        RuntimeError: ensembling is requested for a model with EMA.
    """
    # Test the fids of ensembled checkpoints.
    # Shouldn't be used for models with ema.
    if self.config.fast_fid.ensemble:
        if self.config.model.ema:
            raise RuntimeError("Cannot apply ensembling to models with EMA.")
        self.fast_ensemble_fid()
        return

    from evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    batch_size = self.config.fast_fid.batch_size
    image_shape = (
        self.config.data.channels,
        self.config.data.image_size,
        self.config.data.image_size,
    )

    fids = {}
    for ckpt in tqdm.tqdm(
        range(self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000),
        desc="processing ckpt",
    ):
        states = torch.load(
            os.path.join(self.args.log_path, f'checkpoint_{ckpt}.pth'),
            map_location=self.config.device,
        )

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        num_iters = self.config.fast_fid.num_samples // batch_size
        output_path = os.path.join(self.args.image_folder, 'ckpt_{}'.format(ckpt))
        os.makedirs(output_path, exist_ok=True)
        for i in range(num_iters):
            init_samples = torch.rand(batch_size, *image_shape, device=self.config.device)
            init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                score,
                sigmas,
                self.config.fast_fid.n_steps_each,
                self.config.fast_fid.step_lr,
                verbose=self.config.fast_fid.verbose,
                denoise=self.config.sampling.denoise,
            )

            final_samples = all_samples[-1]
            for offset, sample in enumerate(final_samples):
                sample = sample.view(*image_shape)
                sample = inverse_data_transform(self.config, sample)
                # Number images across batches; the per-batch index alone
                # would make every batch overwrite the previous one.
                image_index = i * batch_size + offset
                save_image(
                    sample,
                    os.path.join(output_path, 'sample_{}.png'.format(image_index)),
                )

        stat_path = get_fid_stats_path(self.args, self.config, download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))

    with open(os.path.join(self.args.image_folder, 'fids.pickle'), 'wb') as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
def sample(self):
    """Load a checkpoint and generate images with annealed Langevin dynamics.

    The checkpoint is ``checkpoint.pth`` (or ``checkpoint_<ckpt_id>.pth``
    when ``sampling.ckpt_id`` is not ``None``) under ``self.args.log_path``.
    The mode is selected from ``self.config.sampling``:

    * ``fid`` set: ``num_samples4fid // batch_size`` batches are generated
      and every image is written as ``image_<n>.png``.
    * ``inpainting``: reference images from the dataset are completed; the
      references are saved as ``refer_image.pth`` / ``refer_image.png`` and
      results as ``completion_<tag>.pth``.
    * ``interpolation``: ``anneal_Langevin_dynamics_interpolation`` is run;
      results are saved as ``samples_<tag>.pth``.
    * otherwise: ``anneal_Langevin_dynamics`` is run; results are saved as
      ``samples_<tag>.pth``.

    In the three non-FID modes an ``image_grid_<tag>.png`` is written next
    to each tensor file. ``<tag>`` is the step index when
    ``sampling.final_only`` is false, else ``sampling.ckpt_id``. All output
    goes to ``self.args.image_folder``.
    """
    config = self.config
    sampling = config.sampling
    device = config.device
    image_folder = self.args.image_folder
    image_shape = (config.data.channels, config.data.image_size, config.data.image_size)

    if sampling.ckpt_id is None:
        ckpt_name = 'checkpoint.pth'
    else:
        ckpt_name = f'checkpoint_{sampling.ckpt_id}.pth'
    states = torch.load(os.path.join(self.args.log_path, ckpt_name), map_location=device)

    score = get_model(config)
    score = torch.nn.DataParallel(score)
    score.load_state_dict(states[0], strict=True)

    if config.model.ema:
        ema_helper = EMAHelper(mu=config.model.ema_rate)
        ema_helper.register(score)
        ema_helper.load_state_dict(states[-1])
        ema_helper.ema(score)

    sigmas_th = get_sigmas(config)
    sigmas = sigmas_th.cpu().numpy()

    dataset, _ = get_dataset(self.args, config)
    dataloader = DataLoader(dataset, batch_size=sampling.batch_size, shuffle=True, num_workers=4)

    score.eval()

    def _uniform_noise(*leading_dims):
        """Draw ``torch.rand`` noise of shape ``(*leading_dims, C, H, W)`` and apply ``data_transform``."""
        return data_transform(config, torch.rand(*leading_dims, *image_shape, device=device))

    def _perturb(batch):
        """Move a data batch to the device, transform it and add Gaussian noise scaled by ``sigmas_th[0]``."""
        batch = batch.to(device)
        batch = data_transform(config, batch)
        return batch + sigmas_th[0] * torch.randn_like(batch)

    def _initial_samples():
        """Return the starting point: a perturbed data batch if ``data_init`` is set, else uniform noise."""
        if sampling.data_init:
            batch, _ = next(iter(dataloader))
            return _perturb(batch)
        return _uniform_noise(sampling.batch_size)

    def _save_snapshot(sample, nrow, stem, tag):
        """Write one batch as ``image_grid_<tag>.png`` and ``<stem>_<tag>.pth``."""
        # -1 lets the leading dimension follow the number of samples actually
        # produced (e.g. width * width in inpainting) instead of assuming it
        # equals the configured batch size.
        sample = sample.view(-1, *image_shape)
        sample = inverse_data_transform(config, sample)
        image_grid = make_grid(sample, nrow)
        save_image(image_grid, os.path.join(image_folder, 'image_grid_{}.png'.format(tag)))
        torch.save(sample, os.path.join(image_folder, '{}_{}.pth'.format(stem, tag)))

    def _save_results(all_samples, nrow, stem):
        """Save every entry of ``all_samples``, or only the last one when ``final_only`` is set."""
        if sampling.final_only:
            _save_snapshot(all_samples[-1], nrow, stem, sampling.ckpt_id)
            return
        steps = tqdm.tqdm(enumerate(all_samples), total=len(all_samples), desc="saving image samples")
        for i, sample in steps:
            _save_snapshot(sample, nrow, stem, i)

    if sampling.fid:
        n_rounds = sampling.num_samples4fid // sampling.batch_size
        if sampling.data_init:
            data_iter = iter(dataloader)

        img_id = 0
        for _ in tqdm.tqdm(range(n_rounds), desc='Generating image samples for FID/inception score evaluation'):
            if sampling.data_init:
                try:
                    batch, _ = next(data_iter)
                except StopIteration:
                    # Dataset exhausted: start another pass over it.
                    data_iter = iter(dataloader)
                    batch, _ = next(data_iter)
                samples = _perturb(batch)
            else:
                samples = _uniform_noise(sampling.batch_size)

            all_samples = anneal_Langevin_dynamics(
                samples, score, sigmas,
                sampling.n_steps_each,
                sampling.step_lr,
                verbose=False,
                denoise=sampling.denoise,
            )

            for img in all_samples[-1]:
                img = inverse_data_transform(config, img)
                save_image(img, os.path.join(image_folder, 'image_{}.png'.format(img_id)))
                img_id += 1
        return

    if sampling.inpainting:
        refer_images, _ = next(iter(dataloader))
        refer_images = refer_images.to(device)
        width = int(np.sqrt(sampling.batch_size))

        # width reference images, each paired with width noise initialisations.
        init_samples = _uniform_noise(width, width)
        all_samples = anneal_Langevin_dynamics_inpainting(
            init_samples, refer_images[:width, ...], score, sigmas,
            config.data.image_size,
            sampling.n_steps_each,
            sampling.step_lr,
        )

        torch.save(refer_images[:width, ...], os.path.join(image_folder, 'refer_image.pth'))
        # Repeat every reference image width times so the reference grid has
        # the same layout as the completion grids.
        refer_images = refer_images[:width, None, ...].expand(-1, width, -1, -1, -1).reshape(
            -1, *refer_images.shape[1:])
        save_image(refer_images, os.path.join(image_folder, 'refer_image.png'), nrow=width)

        _save_results(all_samples, width, 'completion')

    elif sampling.interpolation:
        all_samples = anneal_Langevin_dynamics_interpolation(
            _initial_samples(), score, sigmas,
            sampling.n_interpolations,
            sampling.n_steps_each,
            sampling.step_lr,
            verbose=True,
            final_only=sampling.final_only,
        )
        _save_results(all_samples, sampling.n_interpolations, 'samples')

    else:
        all_samples = anneal_Langevin_dynamics(
            _initial_samples(), score, sigmas,
            sampling.n_steps_each,
            sampling.step_lr,
            verbose=True,
            final_only=sampling.final_only,
            denoise=sampling.denoise,
        )
        _save_results(all_samples, int(np.sqrt(sampling.batch_size)), 'samples')
def calculate_fid(self):
    """Generate samples for a range of checkpoints and compute the FID of each.

    For every checkpoint step from ``fast_fid.begin_ckpt`` to
    ``fast_fid.end_ckpt`` (inclusive, stride 5000),
    ``num_samples // batch_size`` batches are generated with
    ``anneal_Langevin_dynamics`` and saved as PNG files to
    ``<image_folder>/ckpt_<step>/``. All generated samples of a checkpoint
    are min-max rescaled to ``[0, 255]``, passed to the ``fid`` module's
    Inception statistics, and compared against the statistics stored in
    ``fid_stats_cifar10_train.npz``. The resulting ``{step: fid}`` dict is
    pickled to ``<image_folder>/fids.pickle``.

    A checkpoint for which no batch is generated (``num_samples`` smaller
    than ``batch_size``) is reported and skipped.
    """
    import fid
    import pickle
    import tensorflow as tf

    stats_path = "fid_stats_cifar10_train.npz"  # training set statistics
    inception_path = fid.check_or_download_inception("./tmp/")  # download inception network

    # Load the precalculated training set statistics once; they do not depend
    # on the checkpoint being evaluated.
    with np.load(stats_path) as stats:
        mu_real, sigma_real = stats["mu"][:], stats["sigma"][:]

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    batch_size = self.config.fast_fid.batch_size
    image_shape = (
        self.config.data.channels,
        self.config.data.image_size,
        self.config.data.image_size,
    )

    fids = {}
    for ckpt in tqdm.tqdm(
        range(self.config.fast_fid.begin_ckpt, self.config.fast_fid.end_ckpt + 1, 5000),
        desc="processing ckpt",
    ):
        states = torch.load(
            os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
            map_location=self.config.device,
        )

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        num_iters = self.config.fast_fid.num_samples // batch_size
        output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
        os.makedirs(output_path, exist_ok=True)

        generated = []  # final samples of every batch, kept on the CPU
        for i in range(num_iters):
            init_samples = torch.rand(batch_size, *image_shape, device=self.config.device)
            init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                score,
                sigmas,
                self.config.fast_fid.n_steps_each,
                self.config.fast_fid.step_lr,
                verbose=self.config.fast_fid.verbose,
            )

            final_samples = all_samples[-1]
            for offset, sample in enumerate(final_samples):
                sample = sample.view(*image_shape)
                sample = inverse_data_transform(self.config, sample)
                # Number images across batches so that later batches do not
                # overwrite the files of earlier ones.
                save_image(
                    sample,
                    os.path.join(
                        output_path, "sample_{}.png".format(i * batch_size + offset)
                    ),
                )
            generated.append(final_samples.detach().cpu().reshape(-1, *image_shape))

        if not generated:
            print("ckpt: {}, no samples generated; skipping FID".format(ckpt))
            continue

        fid.create_inception_graph(inception_path)  # load the graph into the current TF graph

        # Rescale all samples of this checkpoint to [0, 255]. The conversion to
        # numpy must wrap the whole quotient, not only the denominator.
        images = torch.cat(generated, dim=0)
        lowest, highest = images.min(), images.max()
        images = ((images - lowest) / (highest - lowest)).numpy() * 255
        images = np.transpose(images, [0, 2, 3, 1])  # NCHW -> NHWC

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                images, sess, batch_size=100
            )

        fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
        fids[ckpt] = fid_value  # without this the pickled dict stays empty
        print("FID: %s" % fid_value)

    with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)