def save_grid(d_config, all_samples, gridsize, path, final_only, ckpt_id=None):
    """Save sample batches as PNG image grids.

    Args:
        d_config: data config; ``channels`` and ``image_size`` are read from it
        all_samples: sequence of sample batches, indexed and iterated in order
        gridsize: how many images to save; its integer square root is handed
            to ``make_grid`` as the second positional argument
        path: filename prefix; ``str(suffix) + ".png"`` is appended to it
        final_only: whether to only save the final sample (true), named after
            ``ckpt_id``, or every step of the sampling process (false), named
            after the step index
        ckpt_id: checkpoint id, only used when ``final_only`` is true
    """
    per_row = int(np.sqrt(gridsize))
    image_shape = (d_config.channels, d_config.image_size, d_config.image_size)

    def write_grid(batch, suffix):
        # Reshape to (batch, C, H, W), undo the data transform, save one grid.
        images = batch.view(batch.shape[0], *image_shape)
        images = inverse_data_transform(d_config, images)
        save_image(make_grid(images, per_row), fp=path + str(suffix) + ".png")

    if final_only:
        write_grid(all_samples[-1], ckpt_id)
        return

    progress = tqdm.tqdm(
        enumerate(all_samples),
        total=len(all_samples),
        desc="saving image samples",
    )
    for step, batch in progress:
        write_grid(batch, step)
def sample_sbp(self, S, D):
    """Run both SBP stages from an all-zero batch and save image grids.

    Writes three files into ``self.args.image_folder``: ``reals.png`` (one
    shuffled batch of 5 training images), ``1-final.png`` (output of
    ``sbp_stage1``) and ``2-final.png`` (output of ``sbp_stage2``).

    Args:
        S: forwarded to ``sbp_stage1`` and ``sbp_stage2``.
        D: forwarded to ``sbp_stage1``.
    """
    cfg = self.config
    out_dir = self.args.image_folder

    def save_grid_png(batch, per_row, filename):
        grid = tvu.make_grid(batch, nrow=per_row, padding=1, pad_value=1,
                             normalize=False)
        tvu.save_image(grid, os.path.join(out_dir, filename))

    train_set, _ = get_dataset(self.args, cfg)
    loader = data.DataLoader(
        train_set,
        batch_size=5,
        shuffle=True,
        num_workers=cfg.data.num_workers,
    )
    # Only the first batch is needed as a visual reference.
    for reals, _ in loader:
        save_grid_png(reals, 5, "reals.png")
        break

    state = torch.zeros(
        64,
        cfg.data.channels,
        cfg.data.image_size,
        cfg.data.image_size,
        device=self.device,
    )

    state = sbp_stage1(state, S, cfg, D, tau=self.args.tau, record=True)
    save_grid_png(inverse_data_transform(cfg, state), 8, "1-final.png")

    state = sbp_stage2(state, S, cfg, sigma_sq=self.args.sigma_sq, record=True)
    save_grid_png(inverse_data_transform(cfg, state), 8, "2-final.png")
def run(self):
    """Sample from each checkpoint in the configured range and log its FID.

    For every checkpoint id from ``fast_fid.begin_ckpt`` to
    ``fast_fid.end_ckpt`` (stepping by ``training.snapshot_freq``) this
    draws ``fast_fid.num_samples // sampling.batch_size`` batches via
    ``self.sample``, writes every sample as ``sample_<n>.png`` into
    ``args.image_folder`` (and the denoised ones, when returned, into
    ``args.image_folder_denoised``), then prints the FID and appends it to
    ``<args.fid_folder>/log_FID.txt``.
    """
    bs = self.config.sampling.batch_size
    dataloader = self.get_dataloader(bs=bs)
    sigmas = self.get_sigmas(npy=True)
    score = self.get_model()
    final_samples_denoised = None
    kwargs = {'sigmas': sigmas,
              'nsigma': self.config.sampling.nsigma,
              'step_lr': self.config.sampling.step_lr,
              'final_only': True,
              'target': self.args.target,
              'noise_first': self.config.sampling.noise_first}

    output_path = self.args.image_folder
    output_path_denoised = self.args.image_folder_denoised
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(output_path_denoised, exist_ok=True)
    os.makedirs(self.args.fid_folder, exist_ok=True)

    # Loop-invariant: per-sample (C, H, W) shape and the log file location.
    sizes = [self.config.data.channels,
             self.config.data.image_size,
             self.config.data.image_size]
    log_path = f"{self.args.fid_folder}/log_FID.txt"

    def save_samples(samples, folder, offset):
        # One PNG per sample; `offset` keeps names unique across batches.
        for i, sample in enumerate(samples):
            sample = inverse_data_transform(self.config.data,
                                            sample.view(*sizes))
            save_image(sample,
                       os.path.join(folder,
                                    'sample_{}.png'.format(i + offset)))

    ckpts = range(self.config.fast_fid.begin_ckpt,
                  self.config.fast_fid.end_ckpt + 1,
                  self.config.training.snapshot_freq)
    for ckpt in tqdm.tqdm(ckpts, desc="processing ckpt"):
        # NOTE(review): _load_states receives no checkpoint id here --
        # confirm it actually loads the weights belonging to `ckpt`.
        score = self._load_states(score)
        score.eval()
        kwargs['scorenet'] = score

        for k in range(self.config.fast_fid.num_samples // bs):
            final_samples, final_samples_denoised = self.sample(
                dataloader, saveimages=(k == 0), kwargs=kwargs, bs=bs,
                gridsize=100, ckpt_id=ckpt)
            save_samples(final_samples[0], output_path, k * bs)
            if final_samples_denoised is not None:
                save_samples(final_samples_denoised[0],
                             output_path_denoised, k * bs)

        # The log is opened in a `with` block so the handle is closed after
        # every checkpoint instead of being leaked once per iteration.
        with open(log_path, 'a+') as log_output:
            stat_path = get_fid_stats_path(self.config.data,
                                           fid_stats_folder=self.args.exp)
            fid = get_fid(stat_path, output_path,
                          bs=self.config.fast_fid.batch_size)
            message = "(Samples) {} ckpt: {}, fid: {}".format(
                self.args.doc, ckpt, fid)
            print(message)
            print(message, file=log_output)

            if final_samples_denoised is not None:
                fid_denoised = get_fid(stat_path, output_path_denoised,
                                       bs=self.config.fast_fid.batch_size)
                message = "(Denoised samples) {} ckpt: {}, fid: {}".format(
                    self.args.doc, ckpt, fid_denoised)
                print(message)
                print(message, file=log_output)
def sbp_stage2(x, S, config, sigma_sq, n_stages=1000, record=False, **kwargs):
    """Run the second SBP stage: ``n_stages`` noisy updates, then a denoise.

    Each stage ``k`` evaluates ``e = S(x + config.image_mean, t)`` with ``t``
    a float vector filled with ``k``, moves ``x`` by ``-sigma_sq * e /
    n_stages`` and adds Gaussian noise scaled by
    ``sqrt(sigma_sq) / sqrt(n_stages)``. After the last stage one more
    noise-free update is applied with ``t = n_stages - 1``.

    Args:
        x: initial batch; ``x.size(0)`` is the batch size.
        S: callable invoked as ``S(images, t)``.
        config: must expose ``image_mean`` (a tensor); it is also passed to
            ``inverse_data_transform`` when ``record`` is true.
        sigma_sq: variance that scales both the drift and the noise.
        n_stages: number of update stages.
        record: when true, an image grid of the current state is saved to
            ``./exp/image_samples/images/2-<k>.png`` whenever
            ``k % (n_stages / 10)`` is zero.
        **kwargs: accepted and ignored.

    Returns:
        The final batch. It lives on CUDA when CUDA is available (as before);
        otherwise it stays on the device of ``x``.
    """
    sigma = float(np.sqrt(sigma_sq))
    sqrt_stages = float(np.sqrt(n_stages))
    # The original unconditionally used 'cuda', which crashes on CUDA-less
    # hosts; fall back to the input's own device there.
    device = torch.device('cuda') if torch.cuda.is_available() else x.device

    with torch.no_grad():
        n = x.size(0)
        x_new = x.to(device)
        image_mean = config.image_mean.to(device)[None, ...]

        def drift_update(state, stage):
            # Build `t` on the same device as the state. The original used
            # x.device, which mismatches once x has been moved.
            t = torch.full((n,), float(stage), device=device)
            e = S(state + image_mean, t)
            return state - sigma_sq * e / n_stages

        for k in range(n_stages):
            if record and k % (n_stages / 10) == 0:
                images = tvu.make_grid(
                    inverse_data_transform(config, x_new),
                    nrow=8, padding=1, pad_value=1, normalize=False)
                tvu.save_image(
                    images,
                    os.path.join("./exp/image_samples/images",
                                 "2-%06d.png" % k))

            x0_from_e = drift_update(x_new, k)
            noise = torch.randn_like(x_new)
            x_new = x0_from_e + sigma * noise / sqrt_stages

        if n_stages > 0:
            # denoise: one extra noise-free update at the last time step
            x_new = drift_update(x_new, n_stages - 1)

    return x_new
def sample_fid(self, model):
    """Generate PNG samples for FID evaluation until 50000 images exist.

    Counts the entries already present in ``self.args.image_folder`` and
    resumes numbering from there. Each round draws
    ``config.sampling.batch_size`` Gaussian-noise images, passes them
    through ``self.sample_image`` and writes every result as ``<img_id>.png``.

    Args:
        model: forwarded to ``self.sample_image``.
    """
    config = self.config
    img_id = len(glob.glob(f"{self.args.image_folder}/*"))
    print(f"starting from image {img_id}")
    total_n_samples = 50000
    n = config.sampling.batch_size
    n_rounds = (total_n_samples - img_id) // n

    with paddle.no_grad():
        for _ in tqdm.tqdm(
                range(n_rounds),
                desc="Generating image samples for FID evaluation."):
            # paddle.randn takes the shape as ONE list argument. Passing the
            # dimensions as separate positionals (as before) is a TypeError.
            x = paddle.randn([
                n,
                config.data.channels,
                config.data.image_size,
                config.data.image_size,
            ])

            x = self.sample_image(x, model)
            x = inverse_data_transform(config, x)

            for i in range(n):
                # CHW -> HWC and scale to 8-bit.
                # NOTE(review): assumes inverse_data_transform returns values
                # in [0, 1] -- confirm.
                pixels = np.uint8(x[i].numpy().transpose([1, 2, 0]) * 255)
                Image.fromarray(pixels).save(
                    os.path.join(self.args.image_folder, f"{img_id}.png"))
                img_id += 1
def sample_fid(self, S, D):
    """Generate SBP samples as individual PNG files for FID evaluation.

    Numbering resumes from the count of entries already in
    ``self.args.image_folder`` and continues until
    ``config.sampling.total_n_samples`` is reached, in whole batches only.

    Args:
        S: forwarded to ``self.sample_image_sbp``.
        D: forwarded to ``self.sample_image_sbp``.
    """
    img_id = len(glob.glob(f"{self.args.image_folder}/*"))
    print(f"starting from image {img_id}")

    sampling_cfg = self.config.sampling
    data_cfg = self.config.data
    total_n_samples = sampling_cfg.total_n_samples
    n = sampling_cfg.batch_size
    rounds = range((total_n_samples - img_id) // n)

    with torch.no_grad():
        for _ in tqdm.tqdm(
                rounds, desc="Generating image samples for FID evaluation."):
            start = torch.zeros(
                n,
                data_cfg.channels,
                data_cfg.image_size,
                data_cfg.image_size,
                device=self.device,
            )
            generated = self.sample_image_sbp(start, S, D)
            generated = inverse_data_transform(self.config, generated)

            for offset in range(n):
                target = os.path.join(self.args.image_folder,
                                      f"{img_id}.png")
                tvu.save_image(generated[offset], target)
                img_id += 1
def sample_interpolation(self, S):
    """Interpolate between two noised training images and refine them.

    Saves the two reference images as ``reals.png``. The same Gaussian
    perturbation (variance 0.1) is added to both transformed images, ten
    evenly spaced convex combinations of the pair are built, and the result
    is passed through ``sbp_stage2_interpolation``. The output is saved as
    ``2-final.png`` (one row of ten).

    Args:
        S: forwarded to ``sbp_stage2_interpolation``.
    """
    cfg = self.config
    out_dir = self.args.image_folder

    train_set, _ = get_dataset(self.args, cfg)
    loader = data.DataLoader(
        train_set,
        batch_size=2,
        shuffle=True,
        num_workers=cfg.data.num_workers,
    )
    # Only the first shuffled pair is used.
    for pair, _ in loader:
        reals_grid = tvu.make_grid(pair, nrow=2, padding=1, pad_value=1,
                                   normalize=False)
        tvu.save_image(reals_grid, os.path.join(out_dir, "reals.png"))
        break

    # One noise draw shared by both endpoints.
    shared_noise = torch.randn(
        1,
        cfg.data.channels,
        cfg.data.image_size,
        cfg.data.image_size,
        device=self.device,
    )
    noise = shared_noise.repeat(2, 1, 1, 1) * np.sqrt(0.1)
    endpoints = noise + data_transform(cfg, pair.to(self.device))

    weights = torch.linspace(0, 1, 10).view(-1, 1, 1, 1).to(self.device)
    blended = endpoints[[1]] * weights + endpoints[[0]] * (1 - weights)

    refined = sbp_stage2_interpolation(blended, S, cfg, sigma_sq=0.1,
                                       record=True)
    final_grid = tvu.make_grid(inverse_data_transform(cfg, refined), nrow=10,
                               padding=1, pad_value=1, normalize=False)
    tvu.save_image(final_grid, os.path.join(out_dir, "2-final.png"))
def sample_inpainting(self, S):
    """Inpaint four training images whose columns from index 16 on are masked.

    Writes ``reals.png`` (the originals), ``2-occluded.png`` (originals
    multiplied by the mask) and ``2-final.png`` (the output of
    ``sbp_stage2_inpainting``) into ``self.args.image_folder``, each as a
    single-column grid.

    Args:
        S: forwarded to ``sbp_stage2_inpainting``.
    """
    cfg = self.config
    out_dir = self.args.image_folder

    def save_column(batch, filename):
        grid = tvu.make_grid(batch, nrow=1, padding=1, pad_value=1,
                             normalize=False)
        tvu.save_image(grid, os.path.join(out_dir, filename))

    train_set, _ = get_dataset(self.args, cfg)
    loader = data.DataLoader(
        train_set,
        batch_size=4,
        shuffle=True,
        num_workers=cfg.data.num_workers,
    )
    # Only the first shuffled batch is used.
    for reals, _ in loader:
        save_column(reals, "reals.png")
        break

    # 1 marks known pixels (the first 16 columns); 0 for missing pixels.
    mask = torch.zeros(
        4,
        cfg.data.channels,
        cfg.data.image_size,
        cfg.data.image_size,
        device=self.device,
    )
    mask[..., :16] = 1

    save_column(reals.to(self.device) * mask, "2-occluded.png")

    # Fixed seed, set only after the shuffled batch has been drawn.
    torch.manual_seed(1)
    observed = data_transform(cfg, reals.to(self.device))
    completed = sbp_stage2_inpainting(observed, mask, S, cfg,
                                      sigma_sq=self.args.sigma_sq,
                                      record=True)
    save_column(inverse_data_transform(cfg, completed), "2-final.png")
def fast_ensemble_fid(self):
    """Compute FID for an ensemble of up to five consecutive checkpoints.

    For every checkpoint id in ``[fast_fid.begin_ckpt, fast_fid.end_ckpt]``
    (stride 5000) the score networks of that checkpoint and up to four
    preceding ones are loaded and their outputs averaged.
    ``fast_fid.num_samples // fast_fid.batch_size`` batches are sampled with
    annealed Langevin dynamics and written to
    ``<image_folder>/ckpt_<id>/sample_<n>.png``, and the FID of that folder
    is computed. All FIDs are pickled to ``<image_folder>/fids.pickle``.
    """
    from evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    num_ensembles = 5
    ckpt_stride = 5000
    scores = [NCSN(self.config).to(self.config.device)
              for _ in range(num_ensembles)]
    scores = [torch.nn.DataParallel(score) for score in scores]

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    batch_size = self.config.fast_fid.batch_size
    sample_shape = (self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size)

    fids = {}
    ckpts = range(self.config.fast_fid.begin_ckpt,
                  self.config.fast_fid.end_ckpt + 1,
                  ckpt_stride)
    for ckpt in tqdm.tqdm(ckpts, desc="processing ckpt"):
        begin_ckpt = max(self.config.fast_fid.begin_ckpt,
                         ckpt - (num_ensembles - 1) * ckpt_stride)
        members = range(begin_ckpt, ckpt + ckpt_stride, ckpt_stride)
        for index, member in enumerate(members):
            states = torch.load(
                os.path.join(self.args.log_path,
                             f'checkpoint_{member}.pth'),
                map_location=self.config.device)
            scores[index].load_state_dict(states[0])
            scores[index].eval()

        num_ckpts = (ckpt - begin_ckpt) // ckpt_stride + 1
        active_scores = scores[:num_ckpts]

        def scorenet(x, labels):
            # Average of the ensemble members loaded for this checkpoint.
            return sum([net(x, labels) for net in active_scores]) / num_ckpts

        num_iters = self.config.fast_fid.num_samples // batch_size
        output_path = os.path.join(self.args.image_folder,
                                   'ckpt_{}'.format(ckpt))
        os.makedirs(output_path, exist_ok=True)

        for i in range(num_iters):
            init_samples = torch.rand(batch_size, *sample_shape,
                                      device=self.config.device)
            init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics(
                init_samples, scorenet, sigmas,
                self.config.fast_fid.n_steps_each,
                self.config.fast_fid.step_lr,
                verbose=self.config.fast_fid.verbose,
                denoise=self.config.sampling.denoise)

            final_samples = all_samples[-1]
            for offset, sample in enumerate(final_samples):
                sample = sample.view(*sample_shape)
                sample = inverse_data_transform(self.config, sample)
                # Global index: naming files by the within-batch index alone
                # made every batch overwrite the previous one.
                sample_id = i * batch_size + offset
                save_image(
                    sample,
                    os.path.join(output_path,
                                 'sample_{}.png'.format(sample_id)))

        stat_path = get_fid_stats_path(self.args, self.config, download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))

    with open(os.path.join(self.args.image_folder, 'fids.pickle'),
              'wb') as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
def run(self):
    """Sample batches, classify them, and log mode-coverage statistics.

    Loads the classifier weights from ``./evaluation/mnist_cnn.pt`` and
    draws ``fast_fid.num_samples // fast_fid.batch_size`` batches via
    ``self.sample``. Class counts are accumulated with ``compute_target``.
    The result of ``compute_score`` is printed and appended to
    ``<args.fid_folder>/log_FID.txt``.
    """
    sigmas = self.get_sigmas(npy=True)
    score = self.load_score(eval=True)

    classifier = Net()
    classifier.load_state_dict(torch.load('./evaluation/mnist_cnn.pt'))
    classifier = classifier.cuda()

    targets = np.zeros(1000, dtype=np.int32)
    targets_denoised = np.zeros(1000, dtype=np.int32)

    kwargs = {
        'scorenet': score,
        'sigmas': sigmas,
        'nsigma': self.config.sampling.nsigma,
        'step_lr': self.config.sampling.step_lr,
        'final_only': True,
        'save_freq': self.config.sampling.save_freq,
        'target': self.args.target,
        'noise_first': self.config.sampling.noise_first
    }

    bs = self.config.fast_fid.batch_size
    for k in range(self.config.fast_fid.num_samples // bs):
        all_samples, all_samples_denoised = self.sample(
            None, saveimages=False, kwargs=kwargs, bs=bs)

        img = inverse_data_transform(self.config.data, all_samples[-1])
        targets = compute_target(img, classifier, targets)

        # The original re-scored the non-denoised batch and passed `targets`,
        # so the denoised counts never reflected the denoised samples. If
        # compute_target updates its argument in place, that also counted
        # every batch twice in `targets`.
        if all_samples_denoised is not None:
            img_denoised = inverse_data_transform(self.config.data,
                                                  all_samples_denoised[-1])
            targets_denoised = compute_target(img_denoised, classifier,
                                              targets_denoised)

    # NOTE(review): only the non-denoised counts are scored and logged;
    # `targets_denoised` is accumulated but not reported, as before.
    covered_targets, Kl_score = compute_score(
        targets, self.config.fast_fid.num_samples)

    str_ = " {} || lr {} n_step_each {} | Covered Targets:{}, KL Score:{}]"
    o = str_.format(self.args.doc, self.config.sampling.step_lr,
                    self.config.sampling.nsigma, covered_targets, Kl_score)
    print(o)
    # Opened in a `with` block so the handle is closed, not leaked.
    with open(f"{self.args.fid_folder}/log_FID.txt", 'a+') as log_output:
        print(o, file=log_output)
def sample_sequence(self, model):
    """Save every intermediate prediction of one 8-image sampling run.

    Files are written to ``self.args.image_folder`` as ``<j>_<i>.png``,
    where ``j`` is the position inside the batch and ``i`` the index in the
    sequence returned by ``self.sample_image``.

    Args:
        model: forwarded to ``self.sample_image``.
    """
    cfg = self.config
    start = paddle.randn([
        8,
        cfg.data.channels,
        cfg.data.image_size,
        cfg.data.image_size,
    ])

    # NOTE: This means that we are producing each predicted x0, not x_{t-1} at timestep t.
    with paddle.no_grad():
        _, predictions = self.sample_image(start, model, last=False)

    frames = [inverse_data_transform(cfg, pred) for pred in predictions]

    for step, batch in enumerate(frames):
        for j in range(batch.shape[0]):
            # CHW -> HWC, scaled to 8-bit for PIL.
            pixels = np.uint8(batch[j].numpy().transpose([1, 2, 0]) * 255)
            Image.fromarray(pixels).save(
                os.path.join(self.args.image_folder, f"{j}_{step}.png"))
def sample_interpolation(self, model):
    """Sample images along a spherical interpolation of two noise vectors.

    Two Gaussian latents are drawn and interpolated with ``slerp`` at
    ``alpha = 0.0, 0.1, ...`` (``paddle.arange(0.0, 1.01, 0.1)``). The
    resulting latents go through ``self.sample_image`` in chunks of 8, and
    each output is saved as ``<i>.png`` in ``self.args.image_folder``.

    Args:
        model: forwarded to ``self.sample_image``.
    """
    config = self.config

    def slerp(z1, z2, alpha):
        # Spherical linear interpolation; theta is the angle between z1, z2.
        theta = paddle.acos(
            paddle.sum(z1 * z2) / (paddle.norm(z1) * paddle.norm(z2)))
        return (paddle.sin((1 - alpha) * theta) / paddle.sin(theta) * z1 +
                paddle.sin(alpha * theta) / paddle.sin(theta) * z2)

    # paddle.randn takes the shape as ONE list argument. Passing the
    # dimensions as separate positionals (as before) is a TypeError.
    latent_shape = [
        1,
        config.data.channels,
        config.data.image_size,
        config.data.image_size,
    ]
    z1 = paddle.randn(latent_shape)
    z2 = paddle.randn(latent_shape)

    alpha = paddle.arange(0.0, 1.01, 0.1)
    z_ = [slerp(z1, z2, alpha[i]) for i in range(alpha.shape[0])]

    x = paddle.concat(z_, 0)
    xs = []

    # Hard coded here, modify to your preferences
    with paddle.no_grad():
        for i in range(0, x.shape[0], 8):
            xs.append(self.sample_image(x[i:i + 8], model))
    x = inverse_data_transform(config, paddle.concat(xs, 0))

    for i in range(x.shape[0]):
        # CHW -> HWC, scaled to 8-bit for PIL.
        pixels = np.uint8(x[i].numpy().transpose([1, 2, 0]) * 255)
        Image.fromarray(pixels).save(
            os.path.join(self.args.image_folder, f"{i}.png"))
def sample(self):
    """Transport source images through the barycentric projector.

    Loads the projector weights from ``args.log_path`` (``checkpoint.pth``,
    or ``checkpoint_<sampling.ckpt_id>.pth`` when an id is configured).

    * ``sampling.fid`` false: transports ``sampling.n_batches`` batches and
      writes ``sample_grid.png``, ``source_grid.png``, ``sample.npy`` and
      ``sources.npy`` to ``args.image_folder``. The source outputs reflect
      the last batch only.
    * ``sampling.fid`` true: writes
      ``num_samples4fid // samples_per_batch`` batches as individual
      ``image_<n>.png`` files.
    """
    source_dataset, _ = get_dataset(self.args, self.config.source)
    baryproj = get_bary(self.config)
    baryproj.eval()

    # The original tested `sampling.ckpt_id` but built the file name from
    # `compatibility.ckpt_id`; use the attribute that was actually checked.
    ckpt_id = self.config.sampling.ckpt_id
    ckpt_file = ('checkpoint.pth' if ckpt_id is None
                 else f'checkpoint_{ckpt_id}.pth')
    bp_states = torch.load(os.path.join(self.args.log_path, ckpt_file),
                           map_location=self.config.device)
    baryproj.load_state_dict(bp_states[0])

    if not self.config.sampling.fid:
        dataloader = DataLoader(
            source_dataset,
            batch_size=self.config.sampling.batch_size,
            shuffle=True,
            num_workers=self.config.source.data.num_workers)
        batch_samples = []

        for i in range(self.config.sampling.n_batches):
            # A fresh iterator per batch: each draw is the first batch of a
            # new shuffle (unchanged from the original).
            (Xs, _) = next(iter(dataloader))
            # Moved to the configured device, as the FID branch does.
            Xs = data_transform(self.config.source,
                                Xs).to(self.config.device)
            transport = baryproj(Xs)
            batch_samples.append(
                inverse_data_transform(self.config, transport))

        sample = torch.cat(batch_samples, dim=0)

        image_grid = make_grid(sample[:min(64, len(sample))], nrow=8)
        save_image(image_grid,
                   os.path.join(self.args.image_folder, 'sample_grid.png'))

        source_grid = make_grid(Xs[:min(64, len(Xs))], nrow=8)
        save_image(source_grid,
                   os.path.join(self.args.image_folder, 'source_grid.png'))

        np.save(os.path.join(self.args.image_folder, 'sample.npy'),
                sample.detach().cpu().numpy())
        np.save(os.path.join(self.args.image_folder, 'sources.npy'),
                Xs.detach().cpu().numpy())
    else:
        batch_size = self.config.sampling.samples_per_batch
        total_n_samples = self.config.sampling.num_samples4fid
        n_rounds = total_n_samples // batch_size
        dataloader = DataLoader(
            source_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=self.config.source.data.num_workers)
        data_iter = iter(dataloader)

        img_id = 0
        for _ in tqdm(range(n_rounds),
                      desc='Generating image samples for FID/inception score evaluation'):
            with torch.no_grad():
                try:
                    (Xs, _) = next(data_iter)
                except StopIteration:
                    # More rounds than batches in the source set: start a
                    # new pass instead of crashing.
                    data_iter = iter(dataloader)
                    (Xs, _) = next(data_iter)
                Xs = data_transform(self.config.source,
                                    Xs).to(self.config.device)
                transport = baryproj(Xs)

                for img in transport:
                    img = inverse_data_transform(self.config.target, img)
                    save_image(
                        img,
                        os.path.join(self.args.image_folder,
                                     'image_{}.png'.format(img_id)))
                    img_id += 1

            # Drop the batch references before the next round.
            del Xs
            del transport
def fast_ensemble_fid(self):
    """Compute FID for checkpoint ensembles of the source-conditioned sampler.

    For every checkpoint id in ``[ncsn.fast_fid.begin_ckpt,
    ncsn.fast_fid.end_ckpt]`` (stride 5000) the score networks of that
    checkpoint and up to four preceding ones are loaded and averaged.
    ``num_samples // batch_size`` batches are sampled with annealed Langevin
    dynamics, each conditioned on a batch of source images and the
    compatibility model. Samples are written to
    ``<image_folder>/ckpt_<id>/sample_<n>.png`` and the folder's FID is
    computed. All FIDs are pickled to ``<image_folder>/fids.pickle``.
    """
    from ncsn.evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    num_ensembles = 5
    ckpt_stride = 5000
    scores = [
        NCSN(self.config.ncsn).to(self.config.device)
        for _ in range(num_ensembles)
    ]
    scores = [torch.nn.DataParallel(score) for score in scores]

    sigmas_th = get_sigmas(self.config.ncsn)
    sigmas = sigmas_th.cpu().numpy()

    cpat_ckpt_id = self.config.compatibility.ckpt_id
    cpat_file = ('checkpoint.pth' if cpat_ckpt_id is None
                 else f'checkpoint_{cpat_ckpt_id}.pth')
    cpat_states = torch.load(
        os.path.join('scones', self.config.compatibility.log_path,
                     cpat_file),
        map_location=self.config.device)
    cpat = get_compatibility(self.config)
    cpat.load_state_dict(cpat_states[0])

    source_dataset, _ = get_dataset(self.args, self.config.source)
    source_dataloader = DataLoader(
        source_dataset,
        batch_size=self.config.ncsn.sampling.sources_per_batch,
        shuffle=True,
        num_workers=self.config.source.data.num_workers)
    source_iter = iter(source_dataloader)

    fast_fid = self.config.ncsn.fast_fid
    batch_size = fast_fid.batch_size
    sample_shape = (self.config.target.data.channels,
                    self.config.target.data.image_size,
                    self.config.target.data.image_size)

    fids = {}
    ckpts = range(fast_fid.begin_ckpt, fast_fid.end_ckpt + 1, ckpt_stride)
    for ckpt in tqdm.tqdm(ckpts, desc="processing ckpt"):
        begin_ckpt = max(fast_fid.begin_ckpt,
                         ckpt - (num_ensembles - 1) * ckpt_stride)
        members = range(begin_ckpt, ckpt + ckpt_stride, ckpt_stride)
        for index, member in enumerate(members):
            states = torch.load(
                os.path.join(self.args.log_path,
                             f'checkpoint_{member}.pth'),
                map_location=self.config.device)
            scores[index].load_state_dict(states[0])
            scores[index].eval()

        num_ckpts = (ckpt - begin_ckpt) // ckpt_stride + 1
        active_scores = scores[:num_ckpts]

        def scorenet(x, labels):
            # Average of the ensemble members loaded for this checkpoint.
            return sum([net(x, labels) for net in active_scores]) / num_ckpts

        num_iters = fast_fid.num_samples // batch_size
        output_path = os.path.join(self.args.image_folder,
                                   'ckpt_{}'.format(ckpt))
        os.makedirs(output_path, exist_ok=True)

        for i in range(num_iters):
            try:
                (Xs, _) = next(source_iter)
            except StopIteration:
                # Source set exhausted: start another pass.
                source_iter = iter(source_dataloader)
                (Xs, _) = next(source_iter)
            # NOTE(review): unlike `sample`, the sources are not passed
            # through data_transform here -- confirm that is intended.
            Xs_global = torch.cat(
                [Xs] * self.config.ncsn.sampling.samples_per_source,
                dim=0).to(self.config.device)

            init_samples = torch.rand(batch_size, *sample_shape,
                                      device=self.config.device)
            init_samples = data_transform(self.config.target, init_samples)
            init_samples.requires_grad = True
            init_samples = init_samples.to(self.config.device)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                Xs_global,
                scorenet,
                cpat,
                sigmas,
                fast_fid.n_steps_each,
                fast_fid.step_lr,
                verbose=fast_fid.verbose,
                final_only=self.config.ncsn.sampling.final_only,
                denoise=self.config.ncsn.sampling.denoise)

            final_samples = all_samples[-1]
            for offset, sample in enumerate(final_samples):
                sample = sample.view(*sample_shape)
                sample = inverse_data_transform(self.config.target, sample)
                # Global index: naming files by the within-batch index alone
                # made every batch overwrite the previous one.
                sample_id = i * batch_size + offset
                save_image(
                    sample,
                    os.path.join(output_path,
                                 'sample_{}.png'.format(sample_id)))

        stat_path = get_fid_stats_path(self.args, self.config.ncsn,
                                       download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))

    with open(os.path.join(self.args.image_folder, 'fids.pickle'),
              'wb') as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
def sample(self):
    """Sample target images conditioned on source images (SCONES).

    Loads the score network, the compatibility model and, when a
    ``baryproj`` config section exists and ``ncsn.sampling.data_init`` is
    set, the barycentric projector. All checkpoints are read from
    ``scones/<log_path>/checkpoint[_<ckpt_id>].pth``.

    * ``ncsn.sampling.fid`` false: runs one annealed-Langevin chain on the
      first source batch and writes grids and ``.npy`` dumps to
      ``args.image_folder``. Inpainting and interpolation are not
      implemented for SCONES and raise.
    * ``ncsn.sampling.fid`` true: generates
      ``num_samples4fid // (sources_per_batch * samples_per_source)``
      batches as individual ``image_<n>.png`` files. With
      ``args.save_labels`` it also writes per-round ``.npy`` dumps to
      ``<image_folder>/labels``.

    Raises:
        NotImplementedError: if inpainting or interpolation is requested.
    """
    device = self.config.device
    sampling = self.config.ncsn.sampling
    target_data = self.config.target.data
    image_shape = (target_data.channels, target_data.image_size,
                   target_data.image_size)

    def load_states(log_path, ckpt_id):
        # checkpoint.pth when no id is configured, checkpoint_<id>.pth else.
        filename = ('checkpoint.pth' if ckpt_id is None
                    else f'checkpoint_{ckpt_id}.pth')
        return torch.load(os.path.join('scones', log_path, filename),
                          map_location=device)

    ncsn_states = load_states(sampling.log_path, sampling.ckpt_id)
    score = get_scorenet(self.config)
    score = torch.nn.DataParallel(score)

    sigmas_th = get_sigmas(self.config.ncsn)
    sigmas = sigmas_th.cpu().numpy()

    # Replace the checkpoint's stored sigmas with the configured ones.
    if "module.sigmas" in ncsn_states[0]:
        ncsn_states[0]["module.sigmas"] = sigmas_th

    score.load_state_dict(ncsn_states[0], strict=True)
    score.eval()

    baryproj_data_init = (hasattr(self.config, "baryproj")
                          and sampling.data_init)

    # `bproj` stays None unless it is loaded. Later uses are guarded instead
    # of raising NameError as the original did.
    bproj = None
    if baryproj_data_init:
        bproj_states = load_states(self.config.baryproj.log_path,
                                   self.config.baryproj.ckpt_id)
        bproj = get_bary(self.config)
        bproj.load_state_dict(bproj_states[0])
        bproj = torch.nn.DataParallel(bproj)
        bproj.eval()

    cpat_states = load_states(self.config.compatibility.log_path,
                              self.config.compatibility.ckpt_id)
    cpat = get_compatibility(self.config)
    cpat.load_state_dict(cpat_states[0])

    if self.config.ncsn.model.ema:
        ema_helper = EMAHelper(mu=self.config.ncsn.model.ema_rate)
        ema_helper.register(score)
        ema_helper.load_state_dict(ncsn_states[-1])
        ema_helper.ema(score)

    source_dataset, _ = get_dataset(self.args, self.config.source)
    dataloader = DataLoader(
        source_dataset,
        batch_size=sampling.sources_per_batch,
        shuffle=True,
        num_workers=self.config.source.data.num_workers)
    data_iter = iter(dataloader)
    (Xs, labels) = next(data_iter)
    # Every source image is repeated samples_per_source times.
    Xs_global = torch.cat([Xs] * sampling.samples_per_source,
                          dim=0).to(device)
    Xs_global = data_transform(self.config.source, Xs_global)

    n_sigmas_skip = getattr(sampling, "n_sigmas_skip", 0)

    if not sampling.fid:
        # The NCSN inpainting / interpolation code paths were never ported
        # to SCONES; the dead string-literal copies were removed.
        if sampling.inpainting:
            raise NotImplementedError(
                "Inpainting with SCONES is not currently implemented.")
        if sampling.interpolation:
            raise NotImplementedError(
                "Interpolation with SCONES is not currently implemented.")

        if sampling.data_init:
            if baryproj_data_init:
                with torch.no_grad():
                    init_Xt = (bproj(Xs_global)
                               + sigmas_th[n_sigmas_skip]
                               * torch.randn_like(Xs_global)).detach()
            else:
                init_Xt = (Xs_global + sigmas_th[n_sigmas_skip]
                           * torch.randn_like(Xs_global))
        else:
            init_Xt = torch.rand(
                sampling.sources_per_batch * sampling.samples_per_source,
                *image_shape,
                device=device)
            init_Xt = data_transform(self.config.target, init_Xt)
        init_Xt.requires_grad = True
        init_Xt = init_Xt.to(device)

        all_samples = anneal_Langevin_dynamics(
            init_Xt,
            Xs_global,
            score,
            cpat,
            sigmas,
            sampling.n_steps_each,
            sampling.step_lr,
            verbose=True,
            final_only=sampling.final_only,
            denoise=sampling.denoise,
            n_sigmas_skip=n_sigmas_skip)

        all_samples = torch.stack(all_samples, dim=0)

        if not sampling.final_only:
            all_samples = all_samples.view(
                (-1, sampling.sources_per_batch,
                 sampling.samples_per_source, *image_shape))
            np.save(
                os.path.join(self.args.image_folder, 'all_samples.npy'),
                all_samples.detach().cpu().numpy())

        sample = all_samples[-1].view(
            sampling.sources_per_batch * sampling.samples_per_source,
            *image_shape)
        sample = inverse_data_transform(self.config.target, sample)

        image_grid = make_grid(sample, nrow=sampling.sources_per_batch)
        save_image(
            image_grid,
            os.path.join(self.args.image_folder, 'sample_grid.png'))

        source_grid = make_grid(Xs, nrow=sampling.sources_per_batch)
        save_image(
            source_grid,
            os.path.join(self.args.image_folder, 'source_grid.png'))

        np.save(os.path.join(self.args.image_folder, 'sources.npy'),
                Xs.detach().cpu().numpy())
        np.save(
            os.path.join(self.args.image_folder, 'source_labels.npy'),
            labels.detach().cpu().numpy())

        if bproj is not None:
            # Project the sources once and reuse the result for both the
            # grid and the .npy dump.
            with torch.no_grad():
                bproj_of_Xs = bproj(Xs).detach()
            bproj_of_source = make_grid(
                bproj_of_Xs, nrow=sampling.sources_per_batch)
            save_image(
                bproj_of_source,
                os.path.join(self.args.image_folder, 'bproj_sources.png'))
            np.save(os.path.join(self.args.image_folder, 'bproj.npy'),
                    bproj_of_Xs.cpu().numpy())

        np.save(os.path.join(self.args.image_folder, 'samples.npy'),
                sample.detach().cpu().numpy())

    else:
        batch_size = (sampling.sources_per_batch
                      * sampling.samples_per_source)
        total_n_samples = sampling.num_samples4fid
        n_rounds = total_n_samples // batch_size
        if sampling.data_init:
            # Fresh pass over the same loader settings as above.
            data_iter = iter(dataloader)

        img_id = 0
        for r in tqdm.tqdm(
                range(n_rounds),
                desc='Generating image samples for FID/inception score evaluation'):
            if sampling.data_init:
                try:
                    init_samples, labels = next(data_iter)
                except StopIteration:
                    data_iter = iter(dataloader)
                    init_samples, labels = next(data_iter)
                init_samples = torch.cat(
                    [init_samples] * sampling.samples_per_source, dim=0)
                labels = torch.cat(
                    [labels] * sampling.samples_per_source, dim=0)
                init_samples = init_samples.to(device)
                init_samples = data_transform(self.config.target,
                                              init_samples)

                if baryproj_data_init:
                    with torch.no_grad():
                        bproj_samples = bproj(init_samples).detach()
                else:
                    bproj_samples = torch.clone(init_samples).detach()

                samples = (bproj_samples + sigmas_th[n_sigmas_skip]
                           * torch.randn_like(bproj_samples))
            else:
                samples = torch.rand(batch_size, *image_shape,
                                     device=device)
                init_samples = torch.clone(samples)
                samples = data_transform(self.config.target, samples)
                # No projection exists without data initialisation.
                bproj_samples = None
            samples.requires_grad = True
            samples = samples.to(device)

            # NOTE(review): the chain is always conditioned on `Xs_global`
            # (the first source batch), even when `init_samples` comes from
            # a later batch -- confirm this is intended.
            all_samples = anneal_Langevin_dynamics(
                samples,
                Xs_global,
                score,
                cpat,
                sigmas,
                sampling.n_steps_each,
                sampling.step_lr,
                verbose=True,
                final_only=sampling.final_only,
                denoise=sampling.denoise,
                n_sigmas_skip=n_sigmas_skip)

            samples = all_samples[-1]
            for img in samples:
                img = inverse_data_transform(self.config.target, img)
                save_image(
                    img,
                    os.path.join(self.args.image_folder,
                                 'image_{}.png'.format(img_id)))
                img_id += 1

            if self.args.save_labels:
                save_path = os.path.join(self.args.image_folder, 'labels')
                # np.save does not create missing directories.
                os.makedirs(save_path, exist_ok=True)
                np.save(os.path.join(save_path, f'sources_{r}.npy'),
                        init_samples.detach().cpu().numpy())
                np.save(os.path.join(save_path, f'source_labels_{r}.npy'),
                        labels.detach().cpu().numpy())
                if bproj_samples is not None:
                    np.save(os.path.join(save_path, f"bproj_{r}.npy"),
                            bproj_samples.detach().cpu().numpy())
                np.save(os.path.join(save_path, f"samples_{r}.npy"),
                        samples.detach().cpu().numpy())
def calculate_fid(self):
    """Sample from every saved checkpoint and compute an FID value for each.

    For each checkpoint step in ``range(fast_fid.begin_ckpt,
    fast_fid.end_ckpt + 1, 5000)`` the score network is restored (through the
    EMA helper when ``config.model.ema`` is set), ``fast_fid.num_samples //
    fast_fid.batch_size`` batches are drawn with ``anneal_Langevin_dynamics``,
    every sample is written as a PNG under ``<image_folder>/ckpt_<step>``,
    and an FID value against the statistics stored in
    ``fid_stats_cifar10_train.npz`` is computed with the external ``fid``
    module inside a TensorFlow session.

    The resulting ``{checkpoint step: FID}`` mapping is pickled to
    ``<image_folder>/fids.pickle``.

    Raises:
        ValueError: if ``fast_fid.num_samples`` is smaller than
            ``fast_fid.batch_size`` (no batch would be generated).
    """
    import pickle

    import fid
    import tensorflow as tf

    stats_path = "fid_stats_cifar10_train.npz"  # training set statistics
    inception_path = fid.check_or_download_inception("./tmp/")  # download inception network

    batch_size = self.config.fast_fid.batch_size
    num_iters = self.config.fast_fid.num_samples // batch_size
    if num_iters < 1:
        raise ValueError(
            "fast_fid.num_samples must be at least fast_fid.batch_size "
            "to generate any samples")

    # The precalculated training set statistics do not depend on the
    # checkpoint, so they are read once; the context manager closes the file.
    with np.load(stats_path) as stats:
        mu_real, sigma_real = stats["mu"][:], stats["sigma"][:]

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    image_shape = (
        self.config.data.channels,
        self.config.data.image_size,
        self.config.data.image_size,
    )

    fids = {}
    for ckpt in tqdm.tqdm(
        range(
            self.config.fast_fid.begin_ckpt,
            self.config.fast_fid.end_ckpt + 1,
            5000,
        ),
        desc="processing ckpt",
    ):
        states = torch.load(
            os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
            map_location=self.config.device,
        )

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
        os.makedirs(output_path, exist_ok=True)

        generated = []  # CPU copies of every batch, used for the FID below
        img_id = 0  # running index so later batches do not overwrite earlier PNGs
        for _ in range(num_iters):
            init_samples = torch.rand(
                batch_size,
                *image_shape,
                device=self.config.device,
            )
            init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                score,
                sigmas,
                self.config.fast_fid.n_steps_each,
                self.config.fast_fid.step_lr,
                verbose=self.config.fast_fid.verbose,
            )

            final_samples = all_samples[-1]
            generated.append(
                final_samples.detach().cpu().reshape(-1, *image_shape))
            for sample in final_samples:
                sample = sample.view(*image_shape)
                sample = inverse_data_transform(self.config, sample)
                save_image(
                    sample,
                    os.path.join(output_path, "sample_{}.png".format(img_id)),
                )
                img_id += 1

        fid.create_inception_graph(inception_path)  # load the graph into the current TF graph

        # Rescale all generated samples to [0, 255] on the tensor side, and
        # only then convert to a channels-last NumPy array.
        samples = torch.cat(generated, dim=0)
        lowest, highest = samples.min(), samples.max()
        images = ((samples - lowest) / (highest - lowest) * 255).numpy()
        images = np.transpose(images, [0, 2, 3, 1])

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                images, sess, batch_size=100
            )

        fid_value = fid.calculate_frechet_distance(
            mu_gen, sigma_gen, mu_real, sigma_real
        )
        fids[ckpt] = fid_value  # without this the pickled dict stays empty
        print("FID: %s" % fid_value)

    with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
def sbp_stage1(x, S, config, D, tau, n_stages=1000, record=False, **kwargs):
    """Run the stage-1 update loop for ``n_stages`` steps and return the result.

    Each step perturbs the current state with Gaussian noise scaled by
    ``sqrt(tau) * sqrt(1 - k / n_stages)``, evaluates ``D`` on two such
    perturbations (its output is exponentiated and used as a weight), and
    evaluates ``S`` on the first perturbation shifted by
    ``config.image_mean``. The state is then moved by ``tau * b / n_stages``
    plus fresh Gaussian noise scaled by ``sqrt(tau / n_stages)``.
    Everything runs under ``torch.no_grad()``.

    Args:
        x: Initial state; indexed as a 4-D tensor ``(n, c, h, w)``.
        S: Callable ``S(inputs, t)`` whose output can be viewed as
            ``(1, n, c, h, w)``.
        config: Object exposing an ``image_mean`` tensor; also handed to
            ``inverse_data_transform`` when ``record`` is true.
        D: Callable returning one value per batch element.
        tau: Positive scalar used as noise scale and divisor.
        n_stages: Number of update steps.
        record: When true, image grids are written to
            ``./exp/image_samples/images`` every ``n_stages / 10`` steps,
            plus a final ``1-tweedie.png``.
        **kwargs: Accepted and ignored.

    Returns:
        The final state tensor, on CUDA when available and otherwise on
        ``x``'s own device.
    """
    m = 1
    with torch.no_grad():
        n = x.size(0)
        # Prefer CUDA as before, but fall back to the input's device instead
        # of crashing on CPU-only hosts.
        device = torch.device('cuda') if torch.cuda.is_available() else x.device
        x_new = x.to(device)
        channels, height, width = x_new.shape[1], x_new.shape[2], x_new.shape[3]
        # Keep the time input on the same device as the state fed to S.
        t = torch.zeros(n * m, device=device)

        # Python floats keep NumPy scalars out of the tensor arithmetic.
        sqrt_tau = float(np.sqrt(tau))
        sqrt_stages = float(np.sqrt(n_stages))

        for k in range(n_stages):
            if record:
                if not k % (n_stages / 10):
                    images = tvu.make_grid(inverse_data_transform(
                        config, x_new),
                                           nrow=8,
                                           padding=1,
                                           pad_value=1,
                                           normalize=False)
                    tvu.save_image(
                        images,
                        os.path.join("./exp/image_samples/images",
                                     "1-%06d.png" % k))

            t_k = k / n_stages
            coef = float(np.sqrt(tau) * np.sqrt(1 - t_k))

            # The order of the random draws and of the D / S calls below is
            # kept exactly as before so seeded runs stay reproducible.
            z1 = torch.randn(m, n, channels, height, width,
                             dtype=torch.float32, device=device)
            interpolation1 = x_new.view(1, n, channels, height, width) + coef * z1
            interpolation1 = interpolation1.view(-1, channels, height, width)
            density_ratio1 = torch.exp(D(interpolation1).detach()).view(
                m, n, 1, 1, 1)

            z2 = torch.randn(m, n, channels, height, width,
                             dtype=torch.float32, device=device)
            interpolation2 = x_new.view(1, n, channels, height, width) + coef * z2
            interpolation2 = interpolation2.view(-1, channels, height, width)
            density_ratio2 = torch.exp(D(interpolation2).detach()).view(
                m, n, 1, 1, 1)

            output = S(
                interpolation1 + config.image_mean.to(device)[None, ...],
                t.float())
            e = output.view(m, n, channels, height, width)

            b = torch.mean(
                (-e + coef * z1 / tau) * density_ratio1,
                dim=0) / torch.mean(density_ratio2, dim=0) + x_new / tau

            x0_from_e = x_new + tau * b / n_stages
            noise = torch.randn_like(x_new)
            x_new = x0_from_e + sqrt_tau * noise / sqrt_stages

        # Tweedie's formula
        if record:
            t = torch.zeros(n, device=device)
            e = S(x_new + config.image_mean.to(device)[None, ...], t.float())
            x0_from_e = x_new - e
            images = tvu.make_grid(inverse_data_transform(config, x0_from_e),
                                   nrow=8,
                                   padding=1,
                                   pad_value=1,
                                   normalize=False)
            tvu.save_image(
                images,
                os.path.join("./exp/image_samples/images", "1-tweedie.png"))

    return x_new
def fast_fid(self):
    """Generate samples from each saved checkpoint and record its FID.

    When ``config.fast_fid.ensemble`` is set the work is delegated to
    ``self.fast_ensemble_fid()``. Otherwise, for every checkpoint step in
    ``range(fast_fid.begin_ckpt, fast_fid.end_ckpt + 1, 5000)`` the score
    network is restored (through the EMA helper when ``config.model.ema`` is
    set), ``fast_fid.num_samples // fast_fid.batch_size`` batches are drawn
    with ``anneal_Langevin_dynamics``, each sample is saved as a PNG under
    ``<image_folder>/ckpt_<step>``, and ``get_fid`` is run on that folder.
    The ``{checkpoint step: FID}`` mapping is pickled to
    ``<image_folder>/fids.pickle``.

    Raises:
        RuntimeError: if ensembling is requested together with EMA.
    """
    ### Test the fids of ensembled checkpoints.
    ### Shouldn't be used for pretrained with ema
    if self.config.fast_fid.ensemble:
        if self.config.model.ema:
            raise RuntimeError(
                "Cannot apply ensembling to pretrained with EMA.")
        self.fast_ensemble_fid()
        return

    from ncsn.evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    image_shape = (self.config.data.channels,
                   self.config.data.image_size,
                   self.config.data.image_size)

    fids = {}
    for ckpt in tqdm.tqdm(range(self.config.fast_fid.begin_ckpt,
                                self.config.fast_fid.end_ckpt + 1, 5000),
                          desc="processing ckpt"):
        states = torch.load(os.path.join(self.args.log_path,
                                         f'checkpoint_{ckpt}.pth'),
                            map_location=self.config.device)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        num_iters = self.config.fast_fid.num_samples // self.config.fast_fid.batch_size
        output_path = os.path.join(self.args.image_folder,
                                   'ckpt_{}'.format(ckpt))
        os.makedirs(output_path, exist_ok=True)

        # Running index across batches: a per-batch index would make every
        # batch overwrite the previous one's files, leaving only
        # ``batch_size`` images for the FID computation.
        img_id = 0
        for _ in range(num_iters):
            init_samples = torch.rand(self.config.fast_fid.batch_size,
                                      *image_shape,
                                      device=self.config.device)
            init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                score,
                sigmas,
                self.config.fast_fid.n_steps_each,
                self.config.fast_fid.step_lr,
                verbose=self.config.fast_fid.verbose,
                denoise=self.config.sampling.denoise)

            final_samples = all_samples[-1]
            for sample in final_samples:
                sample = sample.view(*image_shape)
                sample = inverse_data_transform(self.config, sample)
                save_image(
                    sample,
                    os.path.join(output_path,
                                 'sample_{}.png'.format(img_id)))
                img_id += 1

        stat_path = get_fid_stats_path(self.args, self.config, download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))

    with open(os.path.join(self.args.image_folder, 'fids.pickle'),
              'wb') as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
def train(self):
    """Train the score network with ``anneal_dsm_score_estimation``.

    Builds the train/test loaders, the ``DataParallel`` model, the optimizer
    and (when ``config.model.ema`` is set) an EMA helper, optionally resuming
    all of them from ``<log_path>/checkpoint.pth``. During training it

    * logs the loss to TensorBoard and ``logging`` at every step,
    * evaluates one test batch every 100 steps,
    * writes ``checkpoint_<step>.pth`` and ``checkpoint.pth`` every
      ``config.training.snapshot_freq`` steps and, when
      ``config.training.snapshot_sampling`` is set, also saves a 36-sample
      image grid and the sample tensor to ``log_sample_path``.

    Returns:
        ``0`` as soon as the step counter reaches
        ``config.training.n_iters``; ``None`` if the epochs run out first.
    """
    cfg = self.config

    dataset, test_dataset = get_dataset(self.args, cfg)
    shared_loader_args = dict(batch_size=cfg.training.batch_size,
                              shuffle=True,
                              num_workers=cfg.data.num_workers)
    dataloader = DataLoader(dataset, **shared_loader_args)
    test_loader = DataLoader(test_dataset, drop_last=True,
                             **shared_loader_args)
    test_iter = iter(test_loader)
    cfg.input_dim = cfg.data.image_size**2 * cfg.data.channels

    tb_logger = cfg.tb_logger

    score = torch.nn.DataParallel(get_model(cfg))
    optimizer = get_optimizer(cfg, score.parameters())

    start_epoch = 0
    step = 0

    use_ema = cfg.model.ema
    if use_ema:
        ema_helper = EMAHelper(mu=cfg.model.ema_rate)
        ema_helper.register(score)

    if self.args.resume_training:
        states = torch.load(
            os.path.join(self.args.log_path, 'checkpoint.pth'))
        score.load_state_dict(states[0])
        ### Make sure we can resume with different eps
        states[1]['param_groups'][0]['eps'] = cfg.optim.eps
        optimizer.load_state_dict(states[1])
        start_epoch, step = states[2], states[3]
        if use_ema:
            ema_helper.load_state_dict(states[4])

    sigmas = get_sigmas(cfg)

    if cfg.training.log_all_sigmas:
        ### Training-time per-sigma logging is disabled to save time, so
        ### the two training hooks are deliberate no-ops.
        test_loss_per_sigma = [None] * len(sigmas)

        def hook(loss, labels):
            pass

        def tb_hook():
            pass

        def test_hook(loss, labels):
            # Remember the mean test loss of every sigma level in the batch.
            for level in range(len(sigmas)):
                selected = labels == level
                if torch.any(selected):
                    test_loss_per_sigma[level] = torch.mean(loss[selected])

        def test_tb_hook():
            for level, level_loss in enumerate(test_loss_per_sigma):
                if level_loss is not None:
                    tb_logger.add_scalar(
                        'test_loss_sigma_{}'.format(level),
                        level_loss,
                        global_step=step)
    else:
        hook = test_hook = None

        def tb_hook():
            pass

        def test_tb_hook():
            pass

    def eval_network():
        # EMA copy when EMA is enabled, otherwise the live network itself;
        # either way it is switched to eval mode.
        net = ema_helper.ema_copy(score) if use_ema else score
        net.eval()
        return net

    for epoch in range(start_epoch, cfg.training.n_epochs):
        for X, _ in dataloader:
            score.train()
            step += 1

            X = data_transform(cfg, X.to(cfg.device))

            loss = anneal_dsm_score_estimation(
                score, X, sigmas, None, cfg.training.anneal_power, hook)
            tb_logger.add_scalar('loss', loss, global_step=step)
            tb_hook()

            logging.info("step: {}, loss: {}".format(step, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if use_ema:
                ema_helper.update(score)

            if step >= cfg.training.n_iters:
                return 0

            if step % 100 == 0:
                test_score = eval_network()

                try:
                    test_X, _ = next(test_iter)
                except StopIteration:
                    # Test loader exhausted: start another pass over it.
                    test_iter = iter(test_loader)
                    test_X, _ = next(test_iter)

                test_X = data_transform(cfg, test_X.to(cfg.device))

                with torch.no_grad():
                    test_dsm_loss = anneal_dsm_score_estimation(
                        test_score,
                        test_X,
                        sigmas,
                        None,
                        cfg.training.anneal_power,
                        hook=test_hook)
                    tb_logger.add_scalar('test_loss',
                                         test_dsm_loss,
                                         global_step=step)
                    test_tb_hook()
                    logging.info("step: {}, test_loss: {}".format(
                        step, test_dsm_loss.item()))

                    del test_score

            if step % cfg.training.snapshot_freq == 0:
                states = [
                    score.state_dict(),
                    optimizer.state_dict(),
                    epoch,
                    step,
                ]
                if use_ema:
                    states.append(ema_helper.state_dict())

                # Step-stamped snapshot first, then the rolling "latest" one.
                for ckpt_name in ('checkpoint_{}.pth'.format(step),
                                  'checkpoint.pth'):
                    torch.save(states,
                               os.path.join(self.args.log_path, ckpt_name))

                if cfg.training.snapshot_sampling:
                    test_score = eval_network()

                    ## Different part from NeurIPS 2019.
                    ## Random state will be affected because of sampling during training time.
                    grid_shape = (cfg.data.channels, cfg.data.image_size,
                                  cfg.data.image_size)
                    init_samples = torch.rand(36, *grid_shape,
                                              device=cfg.device)
                    init_samples = data_transform(cfg, init_samples)

                    all_samples = anneal_Langevin_dynamics(
                        init_samples,
                        test_score,
                        sigmas.cpu().numpy(),
                        cfg.sampling.n_steps_each,
                        cfg.sampling.step_lr,
                        final_only=True,
                        verbose=True,
                        denoise=cfg.sampling.denoise)

                    last = all_samples[-1]
                    sample = inverse_data_transform(
                        cfg, last.view(last.shape[0], *grid_shape))

                    image_grid = make_grid(sample, 6)
                    save_image(
                        image_grid,
                        os.path.join(self.args.log_sample_path,
                                     'image_grid_{}.png'.format(step)))
                    torch.save(
                        sample,
                        os.path.join(self.args.log_sample_path,
                                     'samples_{}.pth'.format(step)))

                    del test_score
                    del all_samples
def sample(self, return_NCSN=False):
    """Load a checkpoint and draw samples in the mode selected by the config.

    The score network is restored from ``checkpoint.pth`` (or
    ``checkpoint_<ckpt_id>.pth`` when ``config.sampling.ckpt_id`` is set),
    with the EMA weights applied when ``config.model.ema`` is set.

    Modes, chosen from ``config.sampling``:

    * ``fid`` false, ``inpainting`` true: inpainting against a batch of
      reference images; saves the references plus ``image_grid_*`` /
      ``completion_*`` files.
    * ``fid`` false, ``interpolation`` true: interpolation sampling; saves
      ``image_grid_*`` / ``samples_*`` files.
    * ``fid`` false, otherwise: plain annealed Langevin sampling; saves
      ``image_grid_*`` / ``samples_*`` files.
    * ``fid`` true: generates ``num_samples4fid // batch_size`` batches and
      saves every image as ``image_<n>.png``.

    With ``final_only`` unset, the first three modes save one file pair per
    element of the sampler output (tagged by its index); otherwise only the
    last element is saved (tagged by ``ckpt_id``).

    Args:
        return_NCSN: When true, return the restored network in eval mode
            right after the data loader has been built, without sampling.

    Returns:
        The score network when ``return_NCSN`` is true, otherwise ``None``.
    """
    if self.config.sampling.ckpt_id is None:
        ckpt_path = os.path.join(self.args.log_path, 'checkpoint.pth')
    else:
        ckpt_path = os.path.join(
            self.args.log_path,
            f'checkpoint_{self.config.sampling.ckpt_id}.pth')
    states = torch.load(ckpt_path, map_location=self.config.device)

    score = torch.nn.DataParallel(get_model(self.config))
    score.load_state_dict(states[0], strict=True)

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(score)
        ema_helper.load_state_dict(states[-1])
        ema_helper.ema(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    dataset, _ = get_dataset(self.args, self.config)
    dataloader = DataLoader(dataset,
                            batch_size=self.config.sampling.batch_size,
                            shuffle=True,
                            num_workers=4)

    score.eval()

    if return_NCSN:
        return score

    cfg = self.config
    opts = cfg.sampling
    batch_size = opts.batch_size
    image_shape = (cfg.data.channels, cfg.data.image_size,
                   cfg.data.image_size)

    def uniform_init(*leading_dims):
        # Uniform noise of shape (*leading_dims, C, H, W) in model space.
        noise = torch.rand(*leading_dims, *image_shape, device=cfg.device)
        return data_transform(cfg, noise)

    def noisy_data_init(batch):
        # Real data in model space, perturbed with the first sigma.
        batch = data_transform(cfg, batch.to(cfg.device))
        return batch + sigmas_th[0] * torch.randn_like(batch)

    def as_images(flat, count):
        # Reshape to (count, C, H, W) and map back to image space.
        return inverse_data_transform(cfg, flat.view(count, *image_shape))

    def dump(images, nrow, tensor_template, tag):
        # Save the image grid first, then the raw tensor, both tagged.
        grid = make_grid(images, nrow)
        save_image(
            grid,
            os.path.join(self.args.image_folder,
                         'image_grid_{}.png'.format(tag)))
        torch.save(
            images,
            os.path.join(self.args.image_folder,
                         tensor_template.format(tag)))

    if opts.fid:
        n_rounds = opts.num_samples4fid // batch_size
        if opts.data_init:
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4)
            data_iter = iter(dataloader)

        img_id = 0
        for _ in tqdm.tqdm(
                range(n_rounds),
                desc=
                'Generating image samples for FID/inception score evaluation'
        ):
            if opts.data_init:
                try:
                    samples, _ = next(data_iter)
                except StopIteration:
                    # Loader exhausted: start another pass over the data.
                    data_iter = iter(dataloader)
                    samples, _ = next(data_iter)
                samples = noisy_data_init(samples)
            else:
                samples = uniform_init(batch_size)

            all_samples = anneal_Langevin_dynamics(
                samples,
                score,
                sigmas,
                opts.n_steps_each,
                opts.step_lr,
                verbose=False,
                denoise=opts.denoise)

            for img in all_samples[-1]:
                img = inverse_data_transform(cfg, img)
                save_image(
                    img,
                    os.path.join(self.args.image_folder,
                                 'image_{}.png'.format(img_id)))
                img_id += 1
        return

    if opts.inpainting:
        data_iter = iter(dataloader)
        refer_images, _ = next(data_iter)
        refer_images = refer_images.to(cfg.device)
        width = int(np.sqrt(batch_size))
        init_samples = uniform_init(width, width)
        all_samples = anneal_Langevin_dynamics_inpainting(
            init_samples, refer_images[:width, ...], score, sigmas,
            cfg.data.image_size, opts.n_steps_each, opts.step_lr)

        torch.save(
            refer_images[:width, ...],
            os.path.join(self.args.image_folder, 'refer_image.pth'))
        # Repeat each reference image `width` times for the saved grid.
        refer_images = refer_images[:width, None, ...].expand(
            -1, width, -1, -1, -1).reshape(-1, *refer_images.shape[1:])
        save_image(refer_images,
                   os.path.join(self.args.image_folder, 'refer_image.png'),
                   nrow=width)

        if opts.final_only:
            dump(as_images(all_samples[-1], batch_size),
                 int(np.sqrt(batch_size)), 'completion_{}.pth',
                 opts.ckpt_id)
        else:
            for i, sample in enumerate(tqdm.tqdm(all_samples)):
                dump(as_images(sample, batch_size),
                     int(np.sqrt(batch_size)), 'completion_{}.pth', i)
        return

    # Interpolation and plain sampling share the same initialisation.
    if opts.data_init:
        data_iter = iter(dataloader)
        samples, _ = next(data_iter)
        init_samples = noisy_data_init(samples)
    else:
        init_samples = uniform_init(batch_size)

    if opts.interpolation:
        all_samples = anneal_Langevin_dynamics_interpolation(
            init_samples,
            score,
            sigmas,
            opts.n_interpolations,
            opts.n_steps_each,
            opts.step_lr,
            verbose=True,
            final_only=opts.final_only)
        nrow = opts.n_interpolations
    else:
        all_samples = anneal_Langevin_dynamics(
            init_samples,
            score,
            sigmas,
            opts.n_steps_each,
            opts.step_lr,
            verbose=True,
            final_only=opts.final_only,
            denoise=opts.denoise)
        nrow = int(np.sqrt(batch_size))

    if opts.final_only:
        last = all_samples[-1]
        dump(as_images(last, last.shape[0]), nrow, 'samples_{}.pth',
             opts.ckpt_id)
    else:
        for i, sample in tqdm.tqdm(enumerate(all_samples),
                                   total=len(all_samples),
                                   desc="saving image samples"):
            dump(as_images(sample, sample.shape[0]), nrow,
                 'samples_{}.pth', i)
def fast_fid(self):
    """Generate source-conditioned samples per checkpoint and record FID.

    When ``config.ncsn.fast_fid.ensemble`` is set the work is delegated to
    ``self.fast_ensemble_fid()``. Otherwise the compatibility model is
    restored from ``scones/<compatibility.log_path>`` and, for every
    checkpoint step in ``range(fast_fid.begin_ckpt, fast_fid.end_ckpt + 1,
    5000)``, the score network is restored (through the EMA helper when
    ``config.ncsn.model.ema`` is set). Then
    ``fast_fid.num_samples // fast_fid.batch_size`` batches are drawn with
    ``anneal_Langevin_dynamics``, each given a source batch repeated
    ``sampling.samples_per_source`` times. Every sample is saved as a PNG
    under ``<image_folder>/ckpt_<step>`` and ``get_fid`` is run on that
    folder. The ``{checkpoint step: FID}`` mapping is pickled to
    ``<image_folder>/fids.pickle``.

    Raises:
        RuntimeError: if ensembling is requested together with EMA.
    """
    ### Test the fids of ensembled checkpoints.
    ### Shouldn't be used for pretrained with ema
    if self.config.ncsn.fast_fid.ensemble:
        if self.config.ncsn.model.ema:
            raise RuntimeError(
                "Cannot apply ensembling to pretrained with EMA.")
        self.fast_ensemble_fid()
        return

    from ncsn.evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    source_dataset, _ = get_dataset(self.args, self.config.source)
    source_dataloader = DataLoader(
        source_dataset,
        batch_size=self.config.ncsn.sampling.sources_per_batch,
        shuffle=True,
        num_workers=self.config.source.data.num_workers)
    source_iter = iter(source_dataloader)

    score = get_scorenet(self.config.ncsn)
    score = torch.nn.DataParallel(score)

    if self.config.compatibility.ckpt_id is None:
        cpat_states = torch.load(os.path.join(
            'scones', self.config.compatibility.log_path, 'checkpoint.pth'),
                                 map_location=self.config.device)
    else:
        cpat_states = torch.load(os.path.join(
            'scones', self.config.compatibility.log_path,
            f'checkpoint_{self.config.compatibility.ckpt_id}.pth'),
                                 map_location=self.config.device)

    cpat = get_compatibility(self.config)
    cpat.load_state_dict(cpat_states[0])

    sigmas_th = get_sigmas(self.config.ncsn)
    sigmas = sigmas_th.cpu().numpy()

    image_shape = (self.config.target.data.channels,
                   self.config.target.data.image_size,
                   self.config.target.data.image_size)

    fids = {}
    for ckpt in tqdm.tqdm(range(self.config.ncsn.fast_fid.begin_ckpt,
                                self.config.ncsn.fast_fid.end_ckpt + 1,
                                5000),
                          desc="processing ckpt"):
        states = torch.load(os.path.join(self.args.log_path,
                                         f'checkpoint_{ckpt}.pth'),
                            map_location=self.config.device)

        if self.config.ncsn.model.ema:
            ema_helper = EMAHelper(mu=self.config.ncsn.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        num_iters = self.config.ncsn.fast_fid.num_samples // self.config.ncsn.fast_fid.batch_size
        output_path = os.path.join(self.args.image_folder,
                                   'ckpt_{}'.format(ckpt))
        os.makedirs(output_path, exist_ok=True)

        # Running index across batches: a per-batch index would make every
        # batch overwrite the previous one's files, leaving only one
        # batch of images for the FID computation.
        img_id = 0
        for _ in range(num_iters):
            try:
                (Xs, _) = next(source_iter)
            except StopIteration:
                # Source loader exhausted: start another pass over it.
                source_iter = iter(source_dataloader)
                (Xs, _) = next(source_iter)
            Xs_global = torch.cat(
                [Xs] * self.config.ncsn.sampling.samples_per_source,
                dim=0).to(self.config.device)

            init_samples = torch.rand(self.config.ncsn.fast_fid.batch_size,
                                      *image_shape,
                                      device=self.config.device)
            init_samples = data_transform(self.config.target, init_samples)
            init_samples.requires_grad = True
            init_samples = init_samples.to(self.config.device)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                Xs_global,
                score,
                cpat,
                sigmas,
                self.config.ncsn.fast_fid.n_steps_each,
                self.config.ncsn.fast_fid.step_lr,
                verbose=self.config.ncsn.fast_fid.verbose,
                final_only=self.config.ncsn.sampling.final_only,
                denoise=self.config.ncsn.sampling.denoise)

            final_samples = all_samples[-1]
            for sample in final_samples:
                sample = sample.view(*image_shape)
                sample = inverse_data_transform(self.config.target, sample)
                save_image(
                    sample,
                    os.path.join(output_path,
                                 'sample_{}.png'.format(img_id)))
                img_id += 1

        stat_path = get_fid_stats_path(self.args,
                                       self.config.ncsn,
                                       download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))

    with open(os.path.join(self.args.image_folder, 'fids.pickle'),
              'wb') as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)