def run(self):
    """Sample from each checkpoint in the fast-FID range and log its FID.

    For every checkpoint id in ``range(fast_fid.begin_ckpt,
    fast_fid.end_ckpt + 1, training.snapshot_freq)`` this:

    * refreshes the model via ``self._load_states`` and puts it in eval mode;
    * draws ``fast_fid.num_samples // sampling.batch_size`` batches with
      ``self.sample`` and writes every sample as ``sample_<n>.png`` into
      ``args.image_folder`` (denoised samples, when ``self.sample`` returns
      them, go to ``args.image_folder_denoised``);
    * computes the FID of those folders and prints it to stdout and appends
      it to ``<args.fid_folder>/log_FID.txt``.
    """
    bs = self.config.sampling.batch_size
    dataloader = self.get_dataloader(bs=bs)
    sigmas = self.get_sigmas(npy=True)
    score = self.get_model()
    final_samples_denoised = None
    kwargs = {
        'sigmas': sigmas,
        'nsigma': self.config.sampling.nsigma,
        'step_lr': self.config.sampling.step_lr,
        'final_only': True,
        'target': self.args.target,
        'noise_first': self.config.sampling.noise_first,
    }
    output_path = self.args.image_folder
    output_path_denoised = self.args.image_folder_denoised
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(output_path_denoised, exist_ok=True)
    os.makedirs(self.args.fid_folder, exist_ok=True)

    # Every saved sample is reshaped to (channels, image_size, image_size).
    sizes = [self.config.data.channels, self.config.data.image_size, self.config.data.image_size]
    log_file = os.path.join(self.args.fid_folder, 'log_FID.txt')

    def save_batch(samples, folder, offset):
        """Write each sample in ``samples`` to ``folder`` as sample_<offset + i>.png."""
        for i, sample in enumerate(samples):
            sample = inverse_data_transform(self.config.data, sample.view(*sizes))
            save_image(sample, os.path.join(folder, 'sample_{}.png'.format(i + offset)))

    fast_cfg = self.config.fast_fid
    ckpt_range = range(fast_cfg.begin_ckpt, fast_cfg.end_ckpt + 1, self.config.training.snapshot_freq)
    for ckpt in tqdm.tqdm(ckpt_range, desc="processing ckpt"):
        # NOTE(review): ``ckpt`` is not passed to ``_load_states``; confirm it
        # actually loads the checkpoint being evaluated on each iteration.
        score = self._load_states(score)
        score.eval()
        kwargs['scorenet'] = score
        for k in range(fast_cfg.num_samples // bs):
            final_samples, final_samples_denoised = self.sample(
                dataloader, saveimages=(k == 0), kwargs=kwargs, bs=bs, gridsize=100, ckpt_id=ckpt)
            # presumably index 0 holds the final-step batch (kwargs['final_only'] is True) — TODO confirm
            save_batch(final_samples[0], output_path, k * bs)
            if final_samples_denoised is not None:
                save_batch(final_samples_denoised[0], output_path_denoised, k * bs)

        # The context manager closes (and flushes) the log after every
        # checkpoint instead of leaving one open handle per iteration.
        with open(log_file, 'a+') as log_output:
            stat_path = get_fid_stats_path(self.config.data, fid_stats_folder=self.args.exp)
            fid = get_fid(stat_path, output_path, bs=fast_cfg.batch_size)
            message = "(Samples) {} ckpt: {}, fid: {}".format(self.args.doc, ckpt, fid)
            print(message)
            print(message, file=log_output)
            if final_samples_denoised is not None:
                fid_denoised = get_fid(stat_path, output_path_denoised, bs=fast_cfg.batch_size)
                message = "(Denoised samples) {} ckpt: {}, fid: {}".format(self.args.doc, ckpt, fid_denoised)
                print(message)
                print(message, file=log_output)
def fast_ensemble_fid(self):
    """Compute FID for each checkpoint using a sliding-window score ensemble.

    For every checkpoint id ``ckpt`` in ``range(fast_fid.begin_ckpt,
    fast_fid.end_ckpt + 1, 5000)``:

    * load ``checkpoint_<i>.pth`` from ``args.log_path`` for up to five
      consecutive ids ending at ``ckpt`` (never earlier than
      ``fast_fid.begin_ckpt``). Use ``states[0]`` of each as model weights.
    * define the score function as the mean of those models' outputs;
    * run annealed Langevin dynamics for
      ``fast_fid.num_samples // fast_fid.batch_size`` batches and save every
      final sample as ``sample_<n>.png`` under ``<image_folder>/ckpt_<ckpt>/``.
      ``n`` is a running index, so batches do not overwrite each other.
    * compute the FID of that folder.

    The ``{ckpt: fid}`` mapping is pickled to ``<image_folder>/fids.pickle``.
    """
    from evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    num_ensembles = 5
    # Spacing between the checkpoint files that are evaluated / ensembled.
    ckpt_step = 5000
    fast_cfg = self.config.fast_fid
    data_cfg = self.config.data

    scores = [NCSN(self.config).to(self.config.device) for _ in range(num_ensembles)]
    scores = [torch.nn.DataParallel(score) for score in scores]
    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()
    fids = {}
    for ckpt in tqdm.tqdm(range(fast_cfg.begin_ckpt, fast_cfg.end_ckpt + 1, ckpt_step), desc="processing ckpt"):
        begin_ckpt = max(fast_cfg.begin_ckpt, ckpt - (num_ensembles - 1) * ckpt_step)
        num_ckpts = 0
        for member_ckpt in range(begin_ckpt, ckpt + ckpt_step, ckpt_step):
            states = torch.load(os.path.join(self.args.log_path, f'checkpoint_{member_ckpt}.pth'),
                                map_location=self.config.device)
            scores[num_ckpts].load_state_dict(states[0])
            scores[num_ckpts].eval()
            num_ckpts += 1
        # Only the first ``num_ckpts`` slots hold states for this window.
        active_scores = scores[:num_ckpts]

        def scorenet(x, labels):
            # Average the ensemble members' outputs; only called within this iteration.
            return sum([member(x, labels) for member in active_scores]) / num_ckpts

        num_iters = fast_cfg.num_samples // fast_cfg.batch_size
        output_path = os.path.join(self.args.image_folder, 'ckpt_{}'.format(ckpt))
        os.makedirs(output_path, exist_ok=True)
        # Running file index across batches; restarting at 0 per batch would
        # overwrite earlier batches and leave only one batch for FID.
        sample_index = 0
        for _ in range(num_iters):
            init_samples = torch.rand(fast_cfg.batch_size, data_cfg.channels, data_cfg.image_size,
                                      data_cfg.image_size, device=self.config.device)
            init_samples = data_transform(self.config, init_samples)
            all_samples = anneal_Langevin_dynamics(init_samples, scorenet, sigmas,
                                                   fast_cfg.n_steps_each, fast_cfg.step_lr,
                                                   verbose=fast_cfg.verbose,
                                                   denoise=self.config.sampling.denoise)
            final_samples = all_samples[-1]
            for sample in final_samples:
                sample = sample.view(data_cfg.channels, data_cfg.image_size, data_cfg.image_size)
                sample = inverse_data_transform(self.config, sample)
                save_image(sample, os.path.join(output_path, 'sample_{}.png'.format(sample_index)))
                sample_index += 1
        stat_path = get_fid_stats_path(self.args, self.config, download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))
    with open(os.path.join(self.args.image_folder, 'fids.pickle'), 'wb') as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
def fast_fid(self):
    """Compute the FID of every checkpoint in the fast-FID range.

    If ``config.fast_fid.ensemble`` is set, delegate to
    ``self.fast_ensemble_fid()``. This raises ``RuntimeError`` when
    ``config.model.ema`` is also set, because ensembling is not supported
    together with EMA.

    Otherwise, for each checkpoint id in ``range(fast_fid.begin_ckpt,
    fast_fid.end_ckpt + 1, 5000)``:

    * load ``checkpoint_<id>.pth`` from ``args.log_path`` into the model.
      EMA weights come from ``states[-1]`` when ``model.ema`` is set;
      otherwise ``states[0]`` is used.
    * run annealed Langevin dynamics for
      ``fast_fid.num_samples // fast_fid.batch_size`` batches and save every
      final sample as ``sample_<n>.png`` under ``<image_folder>/ckpt_<id>/``.
      ``n`` is a running index, so batches do not overwrite each other.
    * compute the FID of that folder.

    The ``{ckpt: fid}`` mapping is pickled to ``<image_folder>/fids.pickle``.

    Raises:
        RuntimeError: if ensembling is requested for a model that uses EMA.
    """
    if self.config.fast_fid.ensemble:
        if self.config.model.ema:
            raise RuntimeError("Cannot apply ensembling to models with EMA.")
        self.fast_ensemble_fid()
        return

    from evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    fast_cfg = self.config.fast_fid
    data_cfg = self.config.data
    # Spacing between the checkpoint files that are evaluated.
    ckpt_step = 5000

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)
    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()
    fids = {}
    for ckpt in tqdm.tqdm(range(fast_cfg.begin_ckpt, fast_cfg.end_ckpt + 1, ckpt_step), desc="processing ckpt"):
        states = torch.load(os.path.join(self.args.log_path, f'checkpoint_{ckpt}.pth'),
                            map_location=self.config.device)
        if self.config.model.ema:
            # Assumed EMAHelper semantics (TODO confirm): load the saved EMA
            # weights from states[-1] and copy them into ``score``.
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])
        score.eval()

        num_iters = fast_cfg.num_samples // fast_cfg.batch_size
        output_path = os.path.join(self.args.image_folder, 'ckpt_{}'.format(ckpt))
        os.makedirs(output_path, exist_ok=True)
        # Running file index across batches; restarting at 0 per batch would
        # overwrite earlier batches and leave only one batch for FID.
        sample_index = 0
        for _ in range(num_iters):
            init_samples = torch.rand(fast_cfg.batch_size, data_cfg.channels, data_cfg.image_size,
                                      data_cfg.image_size, device=self.config.device)
            init_samples = data_transform(self.config, init_samples)
            all_samples = anneal_Langevin_dynamics(init_samples, score, sigmas,
                                                   fast_cfg.n_steps_each, fast_cfg.step_lr,
                                                   verbose=fast_cfg.verbose,
                                                   denoise=self.config.sampling.denoise)
            final_samples = all_samples[-1]
            for sample in final_samples:
                sample = sample.view(data_cfg.channels, data_cfg.image_size, data_cfg.image_size)
                sample = inverse_data_transform(self.config, sample)
                save_image(sample, os.path.join(output_path, 'sample_{}.png'.format(sample_index)))
                sample_index += 1
        stat_path = get_fid_stats_path(self.args, self.config, download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))
    with open(os.path.join(self.args.image_folder, 'fids.pickle'), 'wb') as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)