def train(self):
    """Train the annealed score network.

    Logs the training loss every step, evaluates on a held-out batch every
    100 steps, and every ``snapshot_freq`` steps saves a checkpoint and
    (optionally) draws a grid of samples via annealed Langevin dynamics.
    Resuming, EMA tracking, and per-sigma loss hooks are all driven by
    ``self.config`` / ``self.args``.
    """
    dataset, test_dataset = get_dataset(self.args, self.config)
    dataloader = DataLoader(
        dataset,
        batch_size=self.config.training.batch_size,
        shuffle=True,
        num_workers=self.config.data.num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=self.config.training.batch_size,
        shuffle=True,
        num_workers=self.config.data.num_workers,
        drop_last=True,
    )
    test_iter = iter(test_loader)
    # Flattened input dimensionality (image_size^2 * channels), recorded on
    # the config for downstream consumers.
    self.config.input_dim = (
        self.config.data.image_size ** 2 * self.config.data.channels
    )

    tb_logger = self.config.tb_logger

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    optimizer = get_optimizer(self.config, score.parameters())

    start_epoch = 0
    step = 0

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(score)

    if self.args.resume_training:
        # Checkpoint layout: [model, optimizer, epoch, step, (ema)].
        states = torch.load(os.path.join(self.args.log_path, "checkpoint.pth"))
        score.load_state_dict(states[0])
        ### Make sure we can resume with different eps
        states[1]["param_groups"][0]["eps"] = self.config.optim.eps
        optimizer.load_state_dict(states[1])
        start_epoch = states[2]
        step = states[3]
        if self.config.model.ema:
            ema_helper.load_state_dict(states[4])

    sigmas = get_sigmas(self.config)

    if self.config.training.log_all_sigmas:
        ### Commented out training time logging to save time.
        test_loss_per_sigma = [None for _ in range(len(sigmas))]

        def hook(loss, labels):
            # Training-time per-sigma logging intentionally disabled (cost).
            # for i in range(len(sigmas)):
            #     if torch.any(labels == i):
            #         test_loss_per_sigma[i] = torch.mean(loss[labels == i])
            pass

        def tb_hook():
            # for i in range(len(sigmas)):
            #     if test_loss_per_sigma[i] is not None:
            #         tb_logger.add_scalar('test_loss_sigma_{}'.format(i), test_loss_per_sigma[i],
            #                              global_step=step)
            pass

        def test_hook(loss, labels):
            # Record the mean test loss for each noise level that appears
            # in the evaluated batch.
            for i in range(len(sigmas)):
                if torch.any(labels == i):
                    test_loss_per_sigma[i] = torch.mean(loss[labels == i])

        def test_tb_hook():
            for i in range(len(sigmas)):
                if test_loss_per_sigma[i] is not None:
                    tb_logger.add_scalar(
                        "test_loss_sigma_{}".format(i),
                        test_loss_per_sigma[i],
                        global_step=step,
                    )

    else:
        hook = test_hook = None

        def tb_hook():
            pass

        def test_tb_hook():
            pass

    for epoch in range(start_epoch, self.config.training.n_epochs):
        for i, (X, y) in enumerate(dataloader):
            score.train()
            step += 1

            X = X.to(self.config.device)
            X = data_transform(self.config, X)

            loss = anneal_sliced_score_estimation_vr(
                score, X, sigmas, None, self.config.training.anneal_power, hook
            )
            tb_logger.add_scalar("loss", loss, global_step=step)
            tb_hook()

            logging.info("step: {}, loss: {}".format(step, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if self.config.model.ema:
                ema_helper.update(score)

            # Hard iteration cap takes precedence over n_epochs.
            if step >= self.config.training.n_iters:
                return 0

            if step % 100 == 0:
                # Evaluate on one held-out batch, using the EMA weights
                # when EMA is enabled.
                if self.config.model.ema:
                    test_score = ema_helper.ema_copy(score)
                else:
                    test_score = score

                test_score.eval()
                try:
                    test_X, test_y = next(test_iter)
                except StopIteration:
                    # Test loader exhausted; restart it.
                    test_iter = iter(test_loader)
                    test_X, test_y = next(test_iter)

                test_X = test_X.to(self.config.device)
                test_X = data_transform(self.config, test_X)

                test_dsm_loss = anneal_sliced_score_estimation_vr(
                    test_score,
                    test_X,
                    sigmas,
                    None,
                    self.config.training.anneal_power,
                    hook=test_hook,
                )
                tb_logger.add_scalar("test_loss", test_dsm_loss, global_step=step)
                test_tb_hook()
                logging.info(
                    "step: {}, test_loss: {}".format(step, test_dsm_loss.item())
                )
                del test_score

            if step % self.config.training.snapshot_freq == 0:
                states = [
                    score.state_dict(),
                    optimizer.state_dict(),
                    epoch,
                    step,
                ]
                if self.config.model.ema:
                    states.append(ema_helper.state_dict())

                # Save both a step-tagged checkpoint and the rolling one.
                torch.save(
                    states,
                    os.path.join(
                        self.args.log_path, "checkpoint_{}.pth".format(step)
                    ),
                )
                torch.save(
                    states, os.path.join(self.args.log_path, "checkpoint.pth")
                )

                if self.config.training.snapshot_sampling:
                    if self.config.model.ema:
                        test_score = ema_helper.ema_copy(score)
                    else:
                        test_score = score

                    test_score.eval()

                    ## Different part from NeurIPS 2019.
                    ## Random state will be affected because of sampling during training time.
                    init_samples = torch.rand(
                        36,
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                        device=self.config.device,
                    )
                    init_samples = data_transform(self.config, init_samples)

                    all_samples = anneal_Langevin_dynamics(
                        init_samples,
                        test_score,
                        sigmas.cpu().numpy(),
                        self.config.sampling.n_steps_each,
                        self.config.sampling.step_lr,
                        final_only=True,
                        verbose=True,
                    )

                    sample = all_samples[-1].view(
                        all_samples[-1].shape[0],
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )

                    sample = inverse_data_transform(self.config, sample)

                    # 6x6 grid of the 36 generated samples.
                    image_grid = make_grid(sample, 6)
                    save_image(
                        image_grid,
                        os.path.join(
                            self.args.log_sample_path,
                            "image_grid_{}.png".format(step),
                        ),
                    )
                    torch.save(
                        sample,
                        os.path.join(
                            self.args.log_sample_path,
                            "samples_{}.pth".format(step)
                        ),
                    )

                    del test_score
                    del all_samples
def sample(self):
    """Generate images from a trained checkpoint.

    Dispatches on config flags:
      * ``sampling.fid`` False:
          - ``sampling.inpainting``   -> annealed Langevin inpainting against
            reference images from the dataset;
          - ``sampling.interpolation`` -> interpolation sampling;
          - otherwise                  -> plain annealed Langevin sampling.
      * ``sampling.fid`` True: bulk generation of ``num_samples4fid`` images
        (one PNG each) for FID / Inception-score evaluation.
    All outputs are written under ``self.args.image_folder``.
    """
    if self.config.sampling.ckpt_id is None:
        # No explicit checkpoint id -> use the rolling checkpoint.
        states = torch.load(
            os.path.join(self.args.log_path, "checkpoint.pth"),
            map_location=self.config.device,
        )
    else:
        states = torch.load(
            os.path.join(
                self.args.log_path, f"checkpoint_{self.config.sampling.ckpt_id}.pth"
            ),
            map_location=self.config.device,
        )

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)
    score.load_state_dict(states[0], strict=True)

    if self.config.model.ema:
        # Replace the raw weights with the EMA weights before sampling.
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(score)
        ema_helper.load_state_dict(states[-1])
        ema_helper.ema(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    dataset, _ = get_dataset(self.args, self.config)
    dataloader = DataLoader(
        dataset,
        batch_size=self.config.sampling.batch_size,
        shuffle=True,
        num_workers=4,
    )

    score.eval()

    if not self.config.sampling.fid:
        if self.config.sampling.inpainting:
            data_iter = iter(dataloader)
            refer_images, _ = next(data_iter)
            refer_images = refer_images.to(self.config.device)
            # Grid width; assumes batch_size is a perfect square.
            width = int(np.sqrt(self.config.sampling.batch_size))
            init_samples = torch.rand(
                width,
                width,
                self.config.data.channels,
                self.config.data.image_size,
                self.config.data.image_size,
                device=self.config.device,
            )
            init_samples = data_transform(self.config, init_samples)
            all_samples = anneal_Langevin_dynamics_inpainting(
                init_samples,
                refer_images[:width, ...],
                score,
                sigmas,
                self.config.data.image_size,
                self.config.sampling.n_steps_each,
                self.config.sampling.step_lr,
            )

            torch.save(
                refer_images[:width, ...],
                os.path.join(self.args.image_folder, "refer_image.pth"),
            )
            # Repeat each reference image `width` times so the saved grid
            # pairs each reference row with its completions.
            refer_images = (
                refer_images[:width, None, ...]
                .expand(-1, width, -1, -1, -1)
                .reshape(-1, *refer_images.shape[1:])
            )
            save_image(
                refer_images,
                os.path.join(self.args.image_folder, "refer_image.png"),
                nrow=width,
            )

            if not self.config.sampling.final_only:
                # Save every intermediate state of the Langevin chain.
                for i, sample in enumerate(tqdm.tqdm(all_samples)):
                    sample = sample.view(
                        self.config.sampling.batch_size,
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )

                    sample = inverse_data_transform(self.config, sample)

                    image_grid = make_grid(
                        sample, int(np.sqrt(self.config.sampling.batch_size))
                    )
                    save_image(
                        image_grid,
                        os.path.join(
                            self.args.image_folder, "image_grid_{}.png".format(i)
                        ),
                    )
                    torch.save(
                        sample,
                        os.path.join(
                            self.args.image_folder, "completion_{}.pth".format(i)
                        ),
                    )
            else:
                sample = all_samples[-1].view(
                    self.config.sampling.batch_size,
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                )

                sample = inverse_data_transform(self.config, sample)

                image_grid = make_grid(
                    sample, int(np.sqrt(self.config.sampling.batch_size))
                )
                save_image(
                    image_grid,
                    os.path.join(
                        self.args.image_folder,
                        "image_grid_{}.png".format(self.config.sampling.ckpt_id),
                    ),
                )
                torch.save(
                    sample,
                    os.path.join(
                        self.args.image_folder,
                        "completion_{}.pth".format(self.config.sampling.ckpt_id),
                    ),
                )

        elif self.config.sampling.interpolation:
            if self.config.sampling.data_init:
                # Initialize the chain from noisy data (sigma_0 noise).
                data_iter = iter(dataloader)
                samples, _ = next(data_iter)
                samples = samples.to(self.config.device)
                samples = data_transform(self.config, samples)
                init_samples = samples + sigmas_th[0] * torch.randn_like(samples)
            else:
                init_samples = torch.rand(
                    self.config.sampling.batch_size,
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                    device=self.config.device,
                )
                init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics_interpolation(
                init_samples,
                score,
                sigmas,
                self.config.sampling.n_interpolations,
                self.config.sampling.n_steps_each,
                self.config.sampling.step_lr,
                verbose=True,
                final_only=self.config.sampling.final_only,
            )

            if not self.config.sampling.final_only:
                for i, sample in tqdm.tqdm(
                    enumerate(all_samples),
                    total=len(all_samples),
                    desc="saving image samples",
                ):
                    sample = sample.view(
                        sample.shape[0],
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )

                    sample = inverse_data_transform(self.config, sample)

                    # One grid row per interpolation path.
                    image_grid = make_grid(
                        sample, nrow=self.config.sampling.n_interpolations
                    )
                    save_image(
                        image_grid,
                        os.path.join(
                            self.args.image_folder, "image_grid_{}.png".format(i)
                        ),
                    )
                    torch.save(
                        sample,
                        os.path.join(
                            self.args.image_folder, "samples_{}.pth".format(i)
                        ),
                    )
            else:
                sample = all_samples[-1].view(
                    all_samples[-1].shape[0],
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                )

                sample = inverse_data_transform(self.config, sample)

                image_grid = make_grid(
                    sample, self.config.sampling.n_interpolations
                )
                save_image(
                    image_grid,
                    os.path.join(
                        self.args.image_folder,
                        "image_grid_{}.png".format(self.config.sampling.ckpt_id),
                    ),
                )
                torch.save(
                    sample,
                    os.path.join(
                        self.args.image_folder,
                        "samples_{}.pth".format(self.config.sampling.ckpt_id),
                    ),
                )

        else:
            # Plain annealed Langevin dynamics sampling.
            if self.config.sampling.data_init:
                data_iter = iter(dataloader)
                samples, _ = next(data_iter)
                samples = samples.to(self.config.device)
                samples = data_transform(self.config, samples)
                init_samples = samples + sigmas_th[0] * torch.randn_like(samples)
            else:
                init_samples = torch.rand(
                    self.config.sampling.batch_size,
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                    device=self.config.device,
                )
                init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                score,
                sigmas,
                self.config.sampling.n_steps_each,
                self.config.sampling.step_lr,
                verbose=True,
                final_only=self.config.sampling.final_only,
            )

            if not self.config.sampling.final_only:
                for i, sample in tqdm.tqdm(
                    enumerate(all_samples),
                    total=len(all_samples),
                    desc="saving image samples",
                ):
                    sample = sample.view(
                        sample.shape[0],
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )

                    sample = inverse_data_transform(self.config, sample)

                    image_grid = make_grid(
                        sample, int(np.sqrt(self.config.sampling.batch_size))
                    )
                    save_image(
                        image_grid,
                        os.path.join(
                            self.args.image_folder, "image_grid_{}.png".format(i)
                        ),
                    )
                    torch.save(
                        sample,
                        os.path.join(
                            self.args.image_folder, "samples_{}.pth".format(i)
                        ),
                    )
            else:
                sample = all_samples[-1].view(
                    all_samples[-1].shape[0],
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                )

                sample = inverse_data_transform(self.config, sample)

                image_grid = make_grid(
                    sample, int(np.sqrt(self.config.sampling.batch_size))
                )
                save_image(
                    image_grid,
                    os.path.join(
                        self.args.image_folder,
                        "image_grid_{}.png".format(self.config.sampling.ckpt_id),
                    ),
                )
                torch.save(
                    sample,
                    os.path.join(
                        self.args.image_folder,
                        "samples_{}.pth".format(self.config.sampling.ckpt_id),
                    ),
                )

    else:
        # FID branch: generate num_samples4fid images in batches.
        # NOTE(review): any remainder (total % batch_size) is dropped.
        total_n_samples = self.config.sampling.num_samples4fid
        n_rounds = total_n_samples // self.config.sampling.batch_size
        if self.config.sampling.data_init:
            dataloader = DataLoader(
                dataset,
                batch_size=self.config.sampling.batch_size,
                shuffle=True,
                num_workers=4,
            )
            data_iter = iter(dataloader)

        img_id = 0
        for _ in tqdm.tqdm(
            range(n_rounds),
            desc="Generating image samples for FID/inception score evaluation",
        ):
            if self.config.sampling.data_init:
                try:
                    samples, _ = next(data_iter)
                except StopIteration:
                    data_iter = iter(dataloader)
                    samples, _ = next(data_iter)
                samples = samples.to(self.config.device)
                samples = data_transform(self.config, samples)
                samples = samples + sigmas_th[0] * torch.randn_like(samples)
            else:
                samples = torch.rand(
                    self.config.sampling.batch_size,
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                    device=self.config.device,
                )
                samples = data_transform(self.config, samples)

            all_samples = anneal_Langevin_dynamics(
                samples,
                score,
                sigmas,
                self.config.sampling.n_steps_each,
                self.config.sampling.step_lr,
                verbose=False,
            )

            samples = all_samples[-1]
            for img in samples:
                img = inverse_data_transform(self.config, img)

                save_image(
                    img,
                    os.path.join(
                        self.args.image_folder, "image_{}.png".format(img_id)
                    ),
                )
                img_id += 1
def test(self):
    """Evaluate saved checkpoints on the test split.

    Iterates over ``checkpoint_{begin_ckpt..end_ckpt}.pth`` in strides of
    5000 steps, loads each (using EMA weights when EMA is enabled), and logs
    the average annealed score-matching loss over the test dataloader.
    """
    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas = get_sigmas(self.config)

    _, test_dataset = get_dataset(self.args, self.config)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=self.config.test.batch_size,
        shuffle=True,
        num_workers=self.config.data.num_workers,
        drop_last=True,
    )

    verbose = False
    for ckpt in tqdm.tqdm(
        range(self.config.test.begin_ckpt, self.config.test.end_ckpt + 1, 5000),
        desc="processing ckpt:",
    ):
        states = torch.load(
            os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
            map_location=self.config.device,
        )

        if self.config.model.ema:
            # Evaluate the EMA weights rather than the raw parameters.
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        step = 0
        mean_loss = 0.0
        # NOTE(review): the original also declared mean_grad_norm and
        # average_grad_scale, divided them by `step`, but never accumulated
        # or reported them — removed as dead code.
        for x, y in test_dataloader:
            step += 1
            x = x.to(self.config.device)
            x = data_transform(self.config, x)

            with torch.no_grad():
                test_loss = anneal_ESM_scorenet_VR(
                    score, x, sigmas, None, self.config.training.anneal_power
                )
                if verbose:
                    logging.info(
                        "step: {}, test_loss: {}".format(step, test_loss.item())
                    )
                mean_loss += test_loss.item()

        # drop_last=True guarantees at least one full batch => step > 0
        # whenever the test split is non-empty.
        mean_loss /= step
        logging.info("ckpt: {}, average test loss: {}".format(ckpt, mean_loss))
def fast_fid(self):
    """Compute FID for a range of checkpoints (strides of 5000 steps).

    For each checkpoint, draws ``fast_fid.num_samples`` images via annealed
    Langevin dynamics (in ``fast_fid.batch_size`` batches), writes each image
    to ``<image_folder>/ckpt_<step>/sample_<n>.png``, then scores the folder
    against precomputed dataset statistics.  Per-checkpoint FIDs are pickled
    to ``fids.pickle``.
    """
    ### Test the fids of ensembled checkpoints.
    ### Shouldn't be used for models with ema
    if self.config.fast_fid.ensemble:
        if self.config.model.ema:
            raise RuntimeError("Cannot apply ensembling to models with EMA.")
        self.fast_ensemble_fid()
        return

    from evaluation.fid_score import get_fid, get_fid_stats_path
    import pickle

    score = get_model(self.config)
    score = torch.nn.DataParallel(score)

    sigmas_th = get_sigmas(self.config)
    sigmas = sigmas_th.cpu().numpy()

    fids = {}
    for ckpt in tqdm.tqdm(
        range(
            self.config.fast_fid.begin_ckpt,
            self.config.fast_fid.end_ckpt + 1,
            5000,
        ),
        desc="processing ckpt",
    ):
        states = torch.load(
            os.path.join(self.args.log_path, f"checkpoint_{ckpt}.pth"),
            map_location=self.config.device,
        )

        if self.config.model.ema:
            # Score with the EMA weights when EMA is enabled.
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(score)
        else:
            score.load_state_dict(states[0])

        score.eval()

        num_iters = (
            self.config.fast_fid.num_samples // self.config.fast_fid.batch_size
        )
        output_path = os.path.join(self.args.image_folder, "ckpt_{}".format(ckpt))
        os.makedirs(output_path, exist_ok=True)

        # BUG FIX: the filename index must run across ALL batches.  The
        # original reused the per-batch enumerate() index (named `id`, also
        # shadowing the builtin), so when num_iters > 1 every round
        # overwrote the previous round's PNGs and FID was effectively
        # computed on only batch_size images.
        sample_idx = 0
        for _ in range(num_iters):
            init_samples = torch.rand(
                self.config.fast_fid.batch_size,
                self.config.data.channels,
                self.config.data.image_size,
                self.config.data.image_size,
                device=self.config.device,
            )
            init_samples = data_transform(self.config, init_samples)

            all_samples = anneal_Langevin_dynamics(
                init_samples,
                score,
                sigmas,
                self.config.fast_fid.n_steps_each,
                self.config.fast_fid.step_lr,
                verbose=self.config.fast_fid.verbose,
            )

            final_samples = all_samples[-1]
            for sample in final_samples:
                sample = sample.view(
                    self.config.data.channels,
                    self.config.data.image_size,
                    self.config.data.image_size,
                )
                sample = inverse_data_transform(self.config, sample)

                save_image(
                    sample,
                    os.path.join(output_path, "sample_{}.png".format(sample_idx)),
                )
                sample_idx += 1

        stat_path = get_fid_stats_path(self.args, self.config, download=True)
        fid = get_fid(stat_path, output_path)
        fids[ckpt] = fid
        print("ckpt: {}, fid: {}".format(ckpt, fid))

    with open(os.path.join(self.args.image_folder, "fids.pickle"), "wb") as handle:
        pickle.dump(fids, handle, protocol=pickle.HIGHEST_PROTOCOL)
def train(self):
    """Train a PixelCNN++ density model on noise-perturbed images.

    Builds the dataset-specific loss/sampling ops (MNIST uses the 1-D
    discretized mixture of logistics; CIFAR10/CelebA the 3-channel one),
    trains with Adam + StepLR, optionally tracks an EMA of the weights,
    periodically checkpoints, and every 10 epochs saves a grid of
    autoregressively drawn samples.
    """
    obs = (1, 28, 28) if 'MNIST' in self.config.dataset else (3, 32, 32)
    input_channels = obs[0]
    train_loader, test_loader = dataset.get_dataset(self.config)
    model = PixelCNN(self.config)
    model = model.to(self.config.device)
    model = torch.nn.DataParallel(model)
    # Partial with sample=True: forward pass in sampling mode.
    sample_model = partial(model, sample=True)

    rescaling_inv = lambda x: .5 * x + .5
    rescaling = lambda x: (x - .5) * 2.

    if 'MNIST' in self.config.dataset:
        loss_op = lambda real, fake: mix_logistic_loss_1d(real, fake)
        clamp = False
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(
            x, sample_model, self.config.nr_logistic_mix, clamp=clamp)
    elif 'CIFAR10' in self.config.dataset:
        loss_op = lambda real, fake: mix_logistic_loss(real, fake)
        clamp = False
        sample_op = lambda x: sample_from_discretized_mix_logistic(
            x, sample_model, self.config.nr_logistic_mix, clamp=clamp)
    elif 'celeba' in self.config.dataset:
        loss_op = lambda real, fake: mix_logistic_loss(real, fake)
        clamp = False
        sample_op = lambda x: sample_from_discretized_mix_logistic(
            x, sample_model, self.config.nr_logistic_mix, clamp=clamp)
    else:
        # BUG FIX: the braces around the set of names must be escaped,
        # otherwise str.format raises KeyError before the Exception is built.
        raise Exception(
            '{} dataset not in {{mnist, cifar10, celeba}}'.format(
                self.config.dataset))

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(model)
    else:
        ema_helper = None

    optimizer = optim.Adam(model.parameters(), lr=self.config.lr)
    scheduler = lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=self.config.lr_decay)

    ckpt_path = os.path.join(self.args.log, 'pixelcnn_ckpts')
    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_path)

    if self.args.resume_training:
        # Checkpoint layout: [model, optimizer, scheduler, epoch, (ema)].
        state_dict = torch.load(os.path.join(ckpt_path, 'checkpoint.pth'),
                                map_location=self.config.device)
        model.load_state_dict(state_dict[0])
        optimizer.load_state_dict(state_dict[1])
        scheduler.load_state_dict(state_dict[2])
        if len(state_dict) > 3:
            epoch = state_dict[3]
            if self.config.model.ema:
                # BUG FIX: was `states[4]` — an undefined name that raised
                # NameError whenever resuming with EMA enabled.
                ema_helper.load_state_dict(state_dict[4])
        print('model parameters loaded')

    tb_path = os.path.join(self.args.log, 'tensorboard')
    # Fresh tensorboard directory on every run.
    if os.path.exists(tb_path):
        shutil.rmtree(tb_path)
    os.makedirs(tb_path)
    tb_logger = SummaryWriter(log_dir=tb_path)

    def debug_sample(model, data):
        # Autoregressively resample every pixel position from the model,
        # starting from `data` as the conditioning context.
        model.eval()
        data = data.cuda()
        with torch.no_grad():
            for i in range(obs[1]):
                for j in range(obs[2]):
                    data_v = data
                    out_sample = sample_op(data_v)
                    data[:, :, i, j] = out_sample.data[:, :, i, j]
            return data

    print('starting training', flush=True)
    writes = 0
    for epoch in range(self.config.max_epochs):
        train_loss = 0.
        model.train()
        for batch_idx, (input, _) in enumerate(train_loader):
            input = input.cuda(non_blocking=True)
            # input: [-1, 1]
            ## add noise to the entire image
            input = input + torch.randn_like(input) * self.config.noise
            output = model(input)
            loss = loss_op(input, output)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if self.config.model.ema:
                ema_helper.update(model)

            train_loss += loss.item()
            if (batch_idx + 1) % self.config.print_every == 0:
                # Normalize to bits per dimension over the logging window.
                deno = (self.config.print_every * self.config.batch_size
                        * np.prod(obs) * np.log(2.))
                train_loss = train_loss / deno
                print('epoch: {}, batch: {}, loss : {:.4f}'.format(
                    epoch, batch_idx, train_loss), flush=True)
                tb_logger.add_scalar('loss', train_loss, global_step=writes)
                train_loss = 0.
                writes += 1

        # decrease learning rate
        scheduler.step()

        if self.config.model.ema:
            test_model = ema_helper.ema_copy(model)
        else:
            test_model = model

        test_model.eval()
        test_loss = 0.
        with torch.no_grad():
            for batch_idx, (input_var, _) in enumerate(test_loader):
                input_var = input_var.cuda(non_blocking=True)
                input_var = (input_var
                             + torch.randn_like(input_var) * self.config.noise)
                output = test_model(input_var)
                loss = loss_op(input_var, output)
                test_loss += loss.item()
                del loss, output

        # BUG FIX: enumerate starts at 0, so the number of test batches is
        # batch_idx + 1; the original divided by batch_idx, overstating the
        # reported mean test loss.
        deno = ((batch_idx + 1) * self.config.batch_size
                * np.prod(obs) * np.log(2.))
        test_loss = test_loss / deno
        print('epoch: %s, test loss : %s' % (epoch, test_loss), flush=True)
        tb_logger.add_scalar('test_loss', test_loss, global_step=writes)

        if (epoch + 1) % self.config.save_interval == 0:
            state_dict = [
                model.state_dict(),
                optimizer.state_dict(),
                scheduler.state_dict(),
                epoch,
            ]
            if self.config.model.ema:
                state_dict.append(ema_helper.state_dict())
            # Keep a step-tagged copy half as often as the rolling one.
            if (epoch + 1) % (self.config.save_interval * 2) == 0:
                torch.save(state_dict,
                           os.path.join(ckpt_path, f'ckpt_epoch_{epoch}.pth'))
            torch.save(state_dict, os.path.join(ckpt_path, 'checkpoint.pth'))

        if epoch % 10 == 0:
            print('sampling...', flush=True)
            sample_t = debug_sample(test_model, input_var[:25])
            sample_t = rescaling_inv(sample_t)
            if not os.path.exists(os.path.join(self.args.log, 'images')):
                os.makedirs(os.path.join(self.args.log, 'images'))
            utils.save_image(
                sample_t,
                os.path.join(self.args.log, 'images',
                             f'sample_epoch_{epoch}.png'),
                nrow=5, padding=0)

        if self.config.model.ema:
            del test_model