def save_test_sample(real_imgs_lab, fake_imgs_lab1, fake_imgs_lab2, save_path,
                     plot_size=14, scale=1.6, show=False):
    """Create a grid of ground truth, grayscale and 2 colorized images
    (from different sources) and save + display it to the user.

    Args:
        real_imgs_lab: batch tensor of ground-truth LAB images.
        fake_imgs_lab1: first batch of colorized LAB images.
        fake_imgs_lab2: second batch of colorized LAB images.
        save_path: output image file path.
        plot_size: max number of sample rows (capped at the batch size).
        scale: nearest-neighbour upscale factor applied to the canvas.
        show: if True, also display the canvas in a window for 10 s.
    """
    batch_size = real_imgs_lab.size()[0]
    plot_size = min(plot_size, batch_size)
    # White canvas: plot_size rows of 32-px tiles with 6-px gutters,
    # 4 columns (real / gray / fake1 / fake2) with 8-px gutters.
    canvas = np.ones((plot_size * 32 + (plot_size + 1) * 6,
                      4 * 32 + 5 * 8, 3), dtype=np.uint8) * 255
    real_imgs_lab = real_imgs_lab.cpu().numpy()
    fake_imgs_lab1 = fake_imgs_lab1.cpu().numpy()
    fake_imgs_lab2 = fake_imgs_lab2.cpu().numpy()
    for i in range(plot_size):
        # Post-process real and fake samples (LAB -> displayable BGR).
        real_bgr = postprocess(real_imgs_lab[i])
        fake_bgr1 = postprocess(fake_imgs_lab1[i])
        fake_bgr2 = postprocess(fake_imgs_lab2[i])
        grayscale = np.expand_dims(
            cv2.cvtColor(real_bgr.astype(np.float32), cv2.COLOR_BGR2GRAY), 2)
        # Paint row i of the canvas.
        x = (i + 1) * 6 + i * 32
        canvas[x:x + 32, 8:40, :] = real_bgr
        canvas[x:x + 32, 48:80, :] = np.repeat(grayscale, 3, axis=2)
        canvas[x:x + 32, 88:120, :] = fake_bgr1
        canvas[x:x + 32, 128:160, :] = fake_bgr2
    # Scale up for visibility.
    canvas = cv2.resize(canvas, None, fx=scale, fy=scale,
                        interpolation=cv2.INTER_NEAREST)
    # Save. Fix: save_path is already a complete path, so the former
    # single-argument os.path.join(save_path) was a no-op and is dropped.
    cv2.imwrite(save_path, canvas)
    if show:
        cv2.destroyAllWindows()
        cv2.imshow('sample', canvas)
        cv2.waitKey(10000)
def plot_samples(x, fpath):
    """Render up to 30 post-processed samples as a 6-wide grid image at fpath."""
    samples = postprocess(x.detach().cpu())[:30]
    grid = make_grid(samples, nrow=6).permute(1, 2, 0)
    plt.figure(figsize=(5, 5))
    plt.imshow(grid)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(fpath, bbox_inches='tight', pad_inches=0)
def step(engine, batch):
    """One ignite training step.

    Relies on names from the enclosing scope: `model`, `optimizer`, `device`,
    `y_condition`, `y_weight`, `multi_class`, visdom handles (`vis`,
    `train_loss_window`, `train_image_window`, `env`) and the grad-clip
    settings. Returns the dict of losses produced by the compute_loss helper.
    """
    model.train()
    optimizer.zero_grad()
    x, y = batch
    x = x.to(device)
    if y_condition:
        # Class-conditional flow: NLL plus a classification loss on y_logits.
        y = y.to(device)
        z, nll, y_logits = model(x, y)
        losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
    else:
        z, nll, y_logits, im = model(x)
        losses = compute_loss(nll)
    # Periodically log the loss curve and current images to visdom.
    if engine.state.iteration % 250 == 1:
        vis.line(X=np.array([engine.state.iteration]),
                 Y=np.array([losses["total_loss"].item()]),
                 win=train_loss_window, update='append', env=env)
        vis.images(postprocess(im), nrow=16, win=train_image_window, env=env)
    losses["total_loss"].backward()
    # Optional gradient clipping (value and/or norm) before the update.
    if max_grad_clip > 0:
        torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
    if max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return losses
def plot_imgs(imgs, title):
    """Plot each batch in `imgs` as a 6-wide grid in its own subplot and save
    the figure as '<title>.png' inside the module-level `output_folder`."""
    n_panels = len(imgs)
    fig, axes = plt.subplots(1, n_panels, figsize=(n_panels * 5, 4))
    for panel, batch in enumerate(imgs):
        tile = make_grid(postprocess(batch).cpu().detach()[:30],
                         nrow=6).permute(1, 2, 0)
        axes[panel].imshow(tile)
        axes[panel].axis('off')
    plt.savefig(os.path.join(output_folder, f'{title}.png'))
def sample(model):
    """Draw unconditional samples from the model and return them on CPU."""
    with torch.no_grad():
        raw = model(y_onehot=None, temperature=1, reverse=True)
        images = postprocess(raw)
    return images.cpu()
def save_sample(real_imgs_lab, fake_imgs_lab, save_path, plot_size=20,
                scale=2.2, show=False):
    """Create a grid of ground truth, grayscale and colorized images and
    save + display it to the user.

    Args:
        real_imgs_lab: batch tensor of ground-truth LAB images.
        fake_imgs_lab: batch tensor of colorized LAB images.
        save_path: output image file path.
        plot_size: max number of sample columns (capped at the batch size).
        scale: nearest-neighbour upscale factor applied to the canvas.
        show: if True, also display the canvas in a window for 10 s.
    """
    batch_size = real_imgs_lab.size()[0]
    plot_size = min(plot_size, batch_size)
    # White canvas: 3 rows (real / gray / fake) of 32-px tiles with 6-px
    # gutters; one column per sample.
    canvas = np.ones((3 * 32 + 4 * 6,
                      plot_size * 32 + (plot_size + 1) * 6, 3),
                     dtype=np.uint8) * 255
    real_imgs_lab = real_imgs_lab.cpu().numpy()
    fake_imgs_lab = fake_imgs_lab.cpu().numpy()
    for i in range(plot_size):
        # Post-process real and fake samples (LAB -> displayable BGR).
        real_bgr = postprocess(real_imgs_lab[i])
        fake_bgr = postprocess(fake_imgs_lab[i])
        grayscale = np.expand_dims(
            cv2.cvtColor(real_bgr.astype(np.float32), cv2.COLOR_BGR2GRAY), 2)
        # Paint column i of the canvas.
        x = (i + 1) * 6 + i * 32
        canvas[6:38, x:x + 32, :] = real_bgr
        canvas[44:76, x:x + 32, :] = np.repeat(grayscale, 3, axis=2)
        canvas[82:114, x:x + 32, :] = fake_bgr
    # Scale up for visibility.
    canvas = cv2.resize(canvas, None, fx=scale, fy=scale,
                        interpolation=cv2.INTER_NEAREST)
    # Save. Fix: save_path is already a complete path, so the former
    # single-argument os.path.join(save_path) was a no-op and is dropped.
    cv2.imwrite(save_path, canvas)
    if show:
        cv2.destroyAllWindows()
        cv2.imshow('sample', canvas)
        cv2.waitKey(10000)
def sample(model):
    """Draw samples from the model, optionally class-conditioned.

    Uses `hparams`, `num_classes`, `batch_size`, `device` and `postprocess`
    from the enclosing scope; returns post-processed images on CPU.
    """
    with torch.no_grad():
        if hparams['y_condition']:
            y = torch.eye(num_classes)
            # Fix: Tensor.repeat needs one factor per dimension; `y` is 2-D
            # (num_classes x num_classes), so a single repeat factor raised a
            # RuntimeError. Tile the one-hot rows only, then truncate.
            y = y.repeat(batch_size // num_classes + 1, 1)
            y = y[:32, :].to(device)  # number hardcoded in model for now
        else:
            y = None
        images = postprocess(model(y_onehot=y, temperature=1, reverse=True))
    return images.cpu()
def save_inverse_images(self, x, z):
    """Invert latents `z` through the model and save each reconstruction
    next to its corresponding input `x` in one tiled image file."""
    print("Start sample inverse")
    start = time.time()
    assert x.size(0) == z.size(0), "sizes are not the consistent"
    x = postprocess(x, self.n_bits)
    img, _ = self.model(z=z, eps_std=1.0, reverse=True)
    img = img.detach().cpu()
    x = x.detach().cpu()
    img = postprocess(img, self.n_bits)
    # Stack each (input, reconstruction) pair vertically, then lay the
    # pairs out horizontally into a single canvas (at most 10 pairs).
    output = None
    for idx in range(min(self.batch_size, 10)):
        pair = torch.cat((x[idx], img[idx]), dim=1)
        output = pair if output is None else torch.cat((output, pair), dim=2)
    save_image(output,
               os.path.join(self.save_dir,
                            "img-{}.jpg".format(self.global_step)))
    print("End sample inverse")
    print("Elapsed time: {:.5f}".format(time.time() - start))
def Sample(self, n_samples=64, sample_each_row=8, eps_std=1.0):
    """Sample `n_samples` images from the prior and save them to disk,
    one file per row of `sample_each_row` images."""
    assert n_samples % sample_each_row == 0, "cannot arrange the samples"
    drawn = 0
    row_id = 0
    while drawn < n_samples:
        print("sample: {}\{}".format(drawn, n_samples))
        row = None
        for _ in range(sample_each_row):
            img, _ = self.model(z=None, eps_std=eps_std, reverse=True)
            img = postprocess(img, n_bits=self.n_bits)
            drawn += 1
            # Concatenate images side by side along the width axis.
            row = img if row is None else torch.cat((row, img), dim=3)
        save_image(row, os.path.join(self.sample_root,
                                     "sample-{}.png".format(row_id)))
        row_id += 1
def save_sample_images(self, n_samples=20, sample_each_row=5, eps_std=1.0):
    """Sample `n_samples` images from the prior and save them as one tiled
    grid image named after the current global step."""
    print("Start sampling")
    start = time.time()
    assert n_samples % sample_each_row == 0, "cannot arrange the samples"
    samples = []
    for _ in range(n_samples):
        img, _ = self.model(z=None, eps_std=eps_std, reverse=True)
        samples.append(postprocess(img.detach().cpu(), self.n_bits))
    n_rows = n_samples // sample_each_row
    # Group `sample_each_row` images along dim=2, then tile the groups
    # along dim=3 to form the final canvas.
    idx = 0
    output = None
    for _ in range(n_rows):
        strip = None
        for _ in range(sample_each_row):
            strip = (samples[idx] if strip is None
                     else torch.cat((strip, samples[idx]), dim=2))
            idx += 1
        output = strip if output is None else torch.cat((output, strip), dim=3)
    save_image(output,
               os.path.join(self.save_dir,
                            "sample-{}.jpg".format(self.global_step)))
    print("End sampling")
    print("Elapsed time:{:.2f}".format(time.time() - start))
y = torch.eye(num_classes) y = y.repeat(batch_size // num_classes + 1) y = y[:32, :].to(device) # number hardcoded in model for now else: y = None images = model(y_onehot=y, temperature=1, reverse=True) # images = postprocess(model(y_onehot=y, temperature=1, reverse=True)) return images.cpu() batch_size = 32 images = [] N = 10000 iters = N // batch_size + 1 for _ in range(iters): images.append(sample(model)) images = torch.stack(images) images = postprocess(images[:N]) torch.save(images, '10k_samples.pt') ipdb.set_trace() # images = sample(model) # grid = make_grid(postprocess(images[:30]), nrow=6).permute(1,2,0) # plt.figure(figsize=(10,10)) # plt.imshow(grid) # plt.savefig('sample.png') # plt.axis('off')
def main(dataset, dataroot, download, augment, n_workers, eval_batch_size,
         output_dir, db, glow_path, ckpt_name, new_data):
    """Analyse a saved Glow checkpoint on in-distribution data, fixed prior
    samples and OOD batches: plots inputs vs reconstructions, runs
    `run_analysis` per data source and dumps the stats to a JSON file.

    Reads/writes a pickled data bundle under $ROOT1/data so repeated runs
    can reuse the same batches (`new_data` toggles regeneration).
    """
    model = torch.load(glow_path)
    model = model.to(device)
    model.eval()
    if new_data:
        (image_shape, num_classes, train_dataset,
         test_dataset) = check_dataset(dataset, dataroot, augment, download)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=eval_batch_size,
                                      shuffle=False,
                                      num_workers=n_workers,
                                      drop_last=False)
        # One fixed batch of in-distribution test data.
        x = test_loader.__iter__().__next__()[0].to(device)
        # OOD data: resize to 32x32, force 3 channels, then preprocess.
        ood_distributions = ['gaussian', 'rademacher', 'svhn']
        tr = transforms.Compose([])
        tr.transforms.append(transforms.ToPILImage())
        tr.transforms.append(transforms.Resize((32, 32)))
        tr.transforms.append(transforms.ToTensor())
        tr.transforms.append(one_to_three_channels)
        tr.transforms.append(preprocess)
        ood_tensors = [(out_name,
                        torch.stack([
                            tr(x) for x in load_ood_data({
                                'name': out_name,
                                'ood_scale': 1,
                                'n_anom': eval_batch_size,
                            })
                        ]).to(device)) for out_name in ood_distributions]
        # Get fixed `z` for samples (latent shape taken from the prior head).
        _, c2, h, w = model.prior_h.shape
        c = c2 // 2
        zshape = (eval_batch_size, c, h, w)
        zs = torch.randn(zshape).to(device)
        all_data = [('data', x), ('samples', zs)] + ood_tensors
        pickle.dump(
            all_data,
            open(
                os.path.join(os.environ['ROOT1'],
                             'data/flow-analysis-data.pkl'), 'wb'))
    else:
        # Reuse the batches from a previous run for comparability.
        all_data = pickle.load(
            open(
                os.path.join(os.environ['ROOT1'],
                             'data/flow-analysis-data.pkl'), 'rb'))
    f, axs = plt.subplots(2, len(all_data), figsize=(len(all_data) * 3, 6))
    # Plot Data (top row). 'samples' holds latents, so decode them first.
    for (name, x), ax in zip(all_data, axs[0]):
        if name == 'samples':
            with torch.no_grad():
                x = model(z=x, y_onehot=None, temperature=1, reverse=True,
                          batch_size=0)
        plt.subplot(ax)
        grid = make_grid((postprocess(x.cpu(), "")[:16]),
                         nrow=4).permute(1, 2, 0)
        plt.imshow(grid)
        plt.xticks([])
        plt.yticks([])
        plt.title(f"{name}", fontsize=18)
    # Plot Recon (bottom row): same inputs pushed through run_recon.
    for (name, x), ax in zip(all_data, axs[1]):
        if name == 'samples':
            with torch.no_grad():
                x = model(z=x, y_onehot=None, temperature=1, reverse=True,
                          batch_size=0)
        with torch.no_grad():
            x = run_recon(x, model)
        plt.subplot(ax)
        grid = make_grid((postprocess(x.cpu(), "")[:16]),
                         nrow=4).permute(1, 2, 0)
        plt.imshow(grid)
        plt.xticks([])
        plt.yticks([])
        plt.title(f"{name}", fontsize=18)
    plt.suptitle("Top: Input, Bottom: Recon")
    plt.savefig(os.path.join(output_dir, f'all_data_recon_{ckpt_name}.jpeg'),
                bbox_inches='tight')
    # Per-source numeric diagnostics (NaN rates, condition number, logdet,
    # bpd, reconstruction errors), collected into one ordered dict.
    stats = OrderedDict()
    for name, x in all_data:
        if name == 'samples':
            with torch.no_grad():
                x = model(z=x, y_onehot=None, temperature=1, reverse=True,
                          batch_size=0)
        p_pxs, p_ims, cn, dlogdet, bpd, pad, l2_0, l2_9 = run_analysis(
            x, model,
            os.path.join(output_dir, f'recon_{ckpt_name}_{name}.jpeg'))
        stats[f"{name}-percent-pixels-nans"] = p_pxs
        stats[f"{name}-percent-imgs-nans"] = p_ims
        stats[f"{name}-cn"] = cn
        stats[f"{name}-dlogdet"] = dlogdet
        stats[f"{name}-bpd"] = bpd
        stats[f"{name}-pad"] = pad
        stats[f"{name}-l2_0"] = l2_0
        stats[f"{name}-l2_9"] = l2_9
    with open(os.path.join(output_dir, f'results_{ckpt_name}.json'),
              'w') as fp:
        json.dump(stats, fp, indent=4)
def main(args):
    """Load a trained Glow model and synthesize 3-D image volumes by
    interpolating between prior samples, one slice per step.

    If no step size is given, it is estimated from a chord length computed
    via two-point correlations of a thresholded sample batch. Each volume
    is written out as three orthogonal-slice videos (xy / xz / yz).
    """
    # torch.manual_seed(args.seed)
    # Test loading and sampling
    output_folder = os.path.join('results', args.name)
    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)
    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    image_shape = (hparams['patch_size'], hparams['patch_size'],
                   args.n_modalities)
    num_classes = 1
    print('Loading model...')
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])
    model_chkpt = torch.load(
        os.path.join(output_folder, 'checkpoints', args.model))
    model.load_state_dict(model_chkpt['model'])
    model.set_actnorm_init()
    model = model.to(device)
    # Build images
    model.eval()
    temperature = args.temperature
    if args.steps is None:
        # automatically calculate step size if no step size
        fig_dir = os.path.join(output_folder, 'stepnum_results')
        if not os.path.exists(fig_dir):
            os.mkdir(fig_dir)
        print('No step size entered')
        # Create sample of images to estimate chord length.
        with torch.no_grad():
            mean, logs = model.prior(None, None)
            z = gaussian_sample(mean, logs, temperature)
            images_raw = model(z=z, temperature=temperature, reverse=True)
            # Replace numerical blow-ups with mid-gray and clamp to range.
            images_raw[torch.isnan(images_raw)] = 0.5
            images_raw[torch.isinf(images_raw)] = 0.5
            images_raw = torch.clamp(images_raw, -0.5, 0.5)
        images_out = np.transpose(
            np.squeeze(images_raw[:, args.step_modality, :, :].cpu().numpy()),
            (1, 0, 2))
        # Threshold images and compute covariances.
        if args.binary_data:
            thresh = 0
        else:
            thresh = threshold_otsu(images_out)
        images_bin = np.greater(images_out, thresh)
        x_cov = two_point_correlation(images_bin, 0)
        y_cov = two_point_correlation(images_bin, 1)
        # Compute chord length: fit a straight line through the origin of
        # the averaged two-point correlation and use slope S20.
        cov_avg = np.mean(
            np.mean(np.concatenate((x_cov, y_cov), axis=2), axis=0), axis=0)
        N = 5
        S20, _ = curve_fit(straight_line_at_origin(cov_avg[0]), range(0, N),
                           cov_avg[0:N])
        l_pore = np.abs(cov_avg[0] / S20)
        steps = int(l_pore)
        print('Calculated step size: {}'.format(steps))
    else:
        print('Using user-entered step size {}...'.format(args.steps))
        steps = args.steps
    # Build desired number of volumes.
    for iter_vol in range(args.iter):
        if args.iter == 1:
            stack_dir = os.path.join(output_folder, 'image_stacks',
                                     args.save_name)
            print('Sampling images, saving to {}...'.format(args.save_name))
        else:
            # Suffix each volume directory with a zero-padded index.
            stack_dir = os.path.join(
                output_folder, 'image_stacks',
                args.save_name + '_' + str(iter_vol).zfill(3))
            print('Sampling images, saving to {}_'.format(args.save_name) +
                  str(iter_vol).zfill(3) + '...')
        if not os.path.exists(stack_dir):
            os.makedirs(stack_dir)
        with torch.no_grad():
            mean, logs = model.prior(None, None)
            # Linear interpolation weights, one per slice in a step.
            alpha = 1 - torch.reshape(torch.linspace(0, 1, steps=steps),
                                      (-1, 1, 1, 1))
            alpha = alpha.to(device)
            num_imgs = int(np.ceil(hparams['patch_size'] / steps) + 1)
            z = gaussian_sample(mean, logs, temperature)[:num_imgs, ...]
            # Blend consecutive anchor latents to get a smooth z-stack,
            # then trim to exactly patch_size slices.
            z = torch.cat([
                alpha * z[i, ...] + (1 - alpha) * z[i + 1, ...]
                for i in range(num_imgs - 1)
            ])
            z = z[:hparams['patch_size'], ...]
            images_raw = model(z=z, temperature=temperature, reverse=True)
            images_raw[torch.isnan(images_raw)] = 0.5
            images_raw[torch.isinf(images_raw)] = 0.5
            images_raw = torch.clamp(images_raw, -0.5, 0.5)
        # apply median filter to output
        if args.med_filt is not None or args.binary_data:
            for m in range(args.n_modalities):
                if args.binary_data:
                    SE = ball(1)
                else:
                    SE = ball(args.med_filt)
                images_np = np.squeeze(images_raw[:, m, :, :].cpu().numpy())
                images_filt = median_filter(images_np, footprint=SE)
                # Erode binary images
                if args.binary_data:
                    images_filt = np.greater(images_filt, 0)
                    SE = ball(1)
                    images_filt = 1.0 * binary_erosion(images_filt,
                                                       selem=SE) - 0.5
                images_raw[:, m, :, :] = torch.tensor(images_filt,
                                                      device=device)
        # Three orthogonal views of the volume (transpose depth axis).
        images1 = postprocess(images_raw).cpu()
        images2 = postprocess(torch.transpose(images_raw, 0, 2)).cpu()
        images3 = postprocess(torch.transpose(images_raw, 0, 3)).cpu()
        # apply Otsu thresholding to output
        if args.save_binary and not args.binary_data:
            thresh = threshold_otsu(images1.numpy())
            images1[images1 < thresh] = 0
            images1[images1 > thresh] = 255
            images2[images2 < thresh] = 0
            images2[images2 > thresh] = 255
            images3[images3 < thresh] = 0
            images3[images3 > thresh] = 255
        # # erode binary images by 1 px to correct for training image transformation
        # if args.binary_data:
        #     images1 = np.greater(images1.numpy(), 127)
        #     images2 = np.greater(images2.numpy(), 127)
        #     images3 = np.greater(images3.numpy(), 127)
        #     images1 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images1), selem=np.ones((1,2,2))), 1))
        #     images2 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images2), selem=np.ones((2,1,2))), 1))
        #     images3 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images3), selem=np.ones((2,2,1))), 1))
        # save video for each modality
        for m in range(args.n_modalities):
            if args.n_modalities > 1:
                save_dir = os.path.join(stack_dir, 'modality{}'.format(m))
            else:
                save_dir = stack_dir
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            write_video(images1[:, m, :, :], 'xy', hparams, save_dir)
            write_video(images2[:, m, :, :], 'xz', hparams, save_dir)
            write_video(images3[:, m, :, :], 'yz', hparams, save_dir)
    print('Finished!')
def main(kwargs):
    """For a series of saved checkpoints, reconstruct training images by
    gradient descent in latent space (z optimized to minimize L2 against x)
    and save a grid of the optimization trajectory per checkpoint.
    """
    check_manual_seed(kwargs['seed'])
    ds = check_dataset(kwargs['dataset'], kwargs['dataroot'], False,
                       kwargs['download'])
    image_shape, num_classes, train_dataset, test_dataset = ds
    # Note: unsupported for now
    multi_class = False
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=kwargs['batch_size'],
                                   shuffle=True,
                                   num_workers=kwargs['n_workers'],
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=kwargs['eval_batch_size'],
                                  shuffle=False,
                                  num_workers=kwargs['n_workers'],
                                  drop_last=False)
    test_iter = cycle(test_loader)
    train_iter = cycle(train_loader)
    # One fixed batch each of test and train data.
    x_test = torch.cat([test_iter.__next__()[0].to(device) for _ in range(1)],
                       0)
    x_train = torch.cat(
        [train_iter.__next__()[0].to(device) for _ in range(1)], 0)
    if kwargs['pgd_f_project'] == 'l2':
        f_project = lambda x, x_0: l2_project(x, x_0, t=kwargs['pgd_l2_t'])
    elif kwargs['pgd_f_project'] == 'none':
        f_project = None
    else:
        # NOTE(review): bare `raise` outside an except block always errors;
        # presumably meant to raise ValueError on an unknown projection name.
        raise
    run_pgd = lambda x, f_loss: pgd(x,
                                    f_loss=f_loss,
                                    step_size=kwargs['pgd_step_size'],
                                    n_steps=kwargs['pgd_n_steps'],
                                    f_project=f_project)
    assert kwargs['saved_model']
    print("Loading...")
    print(kwargs['saved_model'])
    sample_bpds = []
    test_bpds = []
    train_bpds = []
    idxs = range(0, 20000, 5000)
    for idx in tqdm(idxs):
        model = torch.load(
            os.path.join(os.path.dirname(kwargs['saved_model']),
                         f'ckpt_{idx}.pt'))
        model.eval()
        # (Disabled experiment: PGD attack on model samples and BPD plots —
        # previously ran generate_from_noise + run_pgd and plot_bpds here.)
        bs = 10
        x = x_train[:bs]
        # Start the latent search from all-zeros (shaped like the prior).
        z = get_prior(model, bs, clamp=False)
        z = torch.autograd.Variable(torch.zeros_like(z))
        z.requires_grad_()
        xs = [x.detach().cpu().clone()]
        for n in range(kwargs['pgd_n_steps']):
            x_gen = model(z=z, y_onehot=None, temperature=1, reverse=True,
                          batch_size=0)
            # NOTE(review): precedence makes this (n % n_steps) // 10 == 0,
            # i.e. snapshots on the first 10 iterations only; the intent was
            # probably n % (n_steps // 10) == 0 (every 10% of steps) — confirm.
            if n % kwargs['pgd_n_steps']//10 == 0:
                xs.append(x_gen.detach().cpu().clone())
            # L2 reconstruction loss between target x and decoded x_gen.
            loss = torch.pow(x - x_gen, 2).view(x.size(0), -1).sum(-1).mean()
            g = torch.autograd.grad(loss, [z], create_graph=False)[0]
            # Normalized-gradient descent step on z.
            g = g / (torch.norm(g) + 1e-10)
            if (g != g).sum() > 0:
                # Drop into the debugger on NaN gradients.
                ipdb.set_trace()
            z = z - kwargs['pgd_step_size'] * g
            # if f_project is not None:
            #     x = f_project(x, x_0)
        xs = torch.cat(xs, 0)
        grid = make_grid((postprocess(xs.detach().cpu())),
                         nrow=10).permute(1, 2, 0)
        plt.figure(figsize=(5, 5))
        plt.imshow(grid)
        plt.axis('off')
        plt.tight_layout()
        fpath = os.path.join(kwargs['output_dir'], f'recon_{idx}.png')
        plt.savefig(fpath, bbox_inches='tight', pad_inches=0)
model = model.eval()


def norm_ip(img, min, max):
    # In-place: clamp to [min, max] then rescale to [0, 1].
    # (Parameter names shadow the builtins min/max — kept as-is.)
    img.clamp_(min=min, max=max)
    img.add_(-min).div_(max - min + 1e-5)


def norm_range(t, range):
    # Normalize tensor t in place, either to an explicit (lo, hi) range
    # or to its own min/max when range is None.
    if range is not None:
        norm_ip(t, range[0], range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))


# External classifier used to score the conditional samples.
evaluator = evaluation_model(
    "/home/yellow/deep-learning-and-practice/hw7/classifier_weight.pth")
# Score samples for the standard test conditions.
test_conditions = get_test_conditions(hparams['dataroot']).cuda()
predict_x = postprocess(
    model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
for t in predict_x:  # loop over mini-batch dimension
    norm_range(t, None)
score = evaluator.eval(predict_x, test_conditions)
save_image(predict_x.float(), f"score{score:.3f}.png")
# Repeat for the "new" test conditions.
test_conditions = get_new_test_conditions(hparams['dataroot']).cuda()
predict_x = postprocess(
    model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
for t in predict_x:  # loop over mini-batch dimension
    norm_range(t, None)
newscore = evaluator.eval(predict_x.float(), test_conditions)
save_image(predict_x.float(), f"newscore{newscore:.3f}.png")
def main(kwargs):
    """For each saved checkpoint, compare the BPD of clean inputs against 30
    noise-perturbed copies (x + 0.01*N(0,1)) for model samples, test data and
    train data; save per-checkpoint scatter plots, the raw tensors, and
    summary-statistic plots.
    """
    check_manual_seed(kwargs['seed'])
    ds = check_dataset(kwargs['dataset'], kwargs['dataroot'], False,
                       kwargs['download'])
    image_shape, num_classes, train_dataset, test_dataset = ds
    # Note: unsupported for now
    multi_class = False
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=kwargs['batch_size'],
                                   shuffle=True,
                                   num_workers=kwargs['n_workers'],
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=kwargs['eval_batch_size'],
                                  shuffle=False,
                                  num_workers=kwargs['n_workers'],
                                  drop_last=False)
    test_iter = cycle(test_loader)
    train_iter = cycle(train_loader)
    # Four fixed batches each of test and train data.
    x_test = torch.cat([test_iter.__next__()[0].to(device) for _ in range(4)],
                       0)
    x_train = torch.cat(
        [train_iter.__next__()[0].to(device) for _ in range(4)], 0)
    assert kwargs['saved_model']
    print("Loading...")
    print(kwargs['saved_model'])
    sample_bpds = []
    test_bpds = []
    train_bpds = []
    idxs = range(0, 20000, 5000)
    for idx in tqdm(idxs):
        model = torch.load(
            os.path.join(os.path.dirname(kwargs['saved_model']),
                         f'ckpt_{idx}.pt'))
        with torch.no_grad():
            fake = generate_from_noise(model, 500)
            # 30 BPD estimates of noise-perturbed samples vs the clean BPD.
            sample_bpd = torch.stack([
                model.forward(fake + 0.01 * torch.randn_like(fake).to(device),
                              None,
                              return_details=True)[-1][0] * -1
                for _ in range(30)
            ])
            bpd = model.forward(fake, None, return_details=True)[-1][0] * -1
            plot_bpds(
                sample_bpd, bpd,
                os.path.join(kwargs['output_dir'], f'sample_bpds_{idx}.png'))
            sample_bpds.append((sample_bpd, bpd))
            grid = make_grid((postprocess(fake.detach().cpu())[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(os.path.join(kwargs['output_dir'],
                                     f'sample_{idx}.png'),
                        bbox_inches='tight',
                        pad_inches=0)
        # NOTE(review): a second `with torch.no_grad():` around the test/train
        # passes appears to have been commented out — confirm intended scope.
        # with torch.no_grad():
        test_bpd = torch.stack([
            model.forward(
                x_test + 0.01 * torch.randn_like(x_test).to(device),
                None,
                return_details=True)[-1][0] * -1 for _ in range(30)
        ])
        bpd = model.forward(x_test, None, return_details=True)[-1][0] * -1
        plot_bpds(
            test_bpd, bpd,
            os.path.join(kwargs['output_dir'], f'test_bpds_{idx}.png'))
        test_bpds.append((test_bpd, bpd))
        train_bpd = torch.stack([
            model.forward(
                x_train + 0.01 * torch.randn_like(x_train).to(device),
                None,
                return_details=True)[-1][0] * -1 for _ in range(30)
        ])
        bpd = model.forward(x_train, None, return_details=True)[-1][0] * -1
        plot_bpds(
            train_bpd, bpd,
            os.path.join(kwargs['output_dir'], f'train_bpds_{idx}.png'))
        train_bpds.append((train_bpd, bpd))
    torch.save((sample_bpds, train_bpds, test_bpds),
               os.path.join(kwargs['output_dir'], f'bpds.pt'))
    # sample_bpds, train_bpds, test_bpds = torch.load(os.path.join(kwargs['output_dir'], f'bpds.pt') )
    plot_bpd_stats(
        sample_bpds,
        os.path.join(kwargs['output_dir'], f'stats_sample_bpds.png'))
    plot_bpd_stats(train_bpds,
                   os.path.join(kwargs['output_dir'], f'stats_train_bpds.png'))
    plot_bpd_stats(test_bpds,
                   os.path.join(kwargs['output_dir'], f'stats_test_bpds.png'))
def gan_step(engine, batch):
    """Single adversarial training step (WGAN-GP-style gradient penalty on a
    standard non-saturating BCE GAN): one discriminator update then one
    generator update.

    Relies on enclosing-scope names: `model` (generator), `discriminator`,
    `optimizer`, `D_optimizer`, `y_condition`, `output_dir`, etc.
    Returns a dict with the generator loss under 'total_loss'.
    """
    assert not y_condition
    # Track our own iteration counter on the engine object.
    if 'iter_ind' in dir(engine):
        engine.iter_ind += 1
    else:
        engine.iter_ind = -1
    losses = {}
    model.train()
    discriminator.train()
    x, y = batch
    x = x.to(device)

    # def generate_from_noise(batch_size):
    #     (disabled Glow-based sampler: drew z from the prior shape and
    #      decoded with model(z=..., reverse=True))

    def generate_from_noise(batch_size):
        # Sample images from a fixed 32-dim latent; halved output scale.
        zshape = (batch_size, 32, 1, 1)
        randz = torch.randn(zshape).to(device)
        images = model(randz)
        return images / 2

    def run_noised_disc(discriminator, x):
        # Dequantize inputs before scoring so real/fake share the noise model.
        x = uniform_binning_correction(x)[0]
        return discriminator(x)

    # Train Disc
    fake = generate_from_noise(x.size(0))
    D_real_scores = run_noised_disc(discriminator, x.detach())
    D_fake_scores = run_noised_disc(discriminator, fake.detach())
    ones_target = torch.ones((x.size(0), 1), device=x.device)
    zeros_target = torch.zeros((x.size(0), 1), device=x.device)
    # D_real_accuracy = torch.sum(torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0)
    # D_fake_accuracy = torch.sum(torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0)
    D_real_loss = F.binary_cross_entropy_with_logits(D_real_scores,
                                                     ones_target)
    D_fake_loss = F.binary_cross_entropy_with_logits(D_fake_scores,
                                                     zeros_target)
    D_loss = (D_real_loss + D_fake_loss) / 2
    # Gradient penalty on interpolates between real and fake batches.
    gp = gradient_penalty(x.detach(), fake.detach(),
                          lambda _x: run_noised_disc(discriminator, _x))
    D_loss_plus_gp = D_loss + 10 * gp
    D_optimizer.zero_grad()
    D_loss_plus_gp.backward()
    D_optimizer.step()
    # Train generator (fresh fake batch so D's update is reflected).
    fake = generate_from_noise(x.size(0))
    G_loss = F.binary_cross_entropy_with_logits(
        run_noised_disc(discriminator, fake),
        torch.ones((x.size(0), 1), device=x.device))
    losses['total_loss'] = G_loss
    # G-step
    optimizer.zero_grad()
    losses['total_loss'].backward()
    params = list(model.parameters())
    # Gradient norms kept around for debugging; not otherwise used.
    gnorm = [p.grad.norm() for p in params]
    optimizer.step()
    # if max_grad_clip > 0:
    #     torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
    # if max_grad_norm > 0:
    #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # Periodically dump sample and (dequantized) data grids to disk.
    if engine.iter_ind % 50 == 0:
        grid = make_grid((postprocess(fake.detach().cpu())[:30]),
                         nrow=6).permute(1, 2, 0)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid)
        plt.axis('off')
        plt.savefig(
            os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))
        grid = make_grid(
            (postprocess(uniform_binning_correction(x)[0].cpu())[:30]),
            nrow=6).permute(1, 2, 0)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid)
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f'data_{engine.iter_ind}.png'))
    return losses
def main(kwargs):
    """Track bits-per-dim over training: for each checkpoint, compute BPD of
    model samples, fixed test data and fixed train data, then scatter-plot
    all three series against training iteration.
    """
    check_manual_seed(kwargs['seed'])
    ds = check_dataset(kwargs['dataset'], kwargs['dataroot'], False,
                       kwargs['download'])
    image_shape, num_classes, train_dataset, test_dataset = ds
    # Note: unsupported for now
    multi_class = False
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=kwargs['batch_size'],
                                   shuffle=True,
                                   num_workers=kwargs['n_workers'],
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=kwargs['eval_batch_size'],
                                  shuffle=False,
                                  num_workers=kwargs['n_workers'],
                                  drop_last=False)
    test_iter = cycle(test_loader)
    train_iter = cycle(train_loader)
    # Four fixed batches each of test and train data.
    x_test = torch.cat([test_iter.__next__()[0].to(device) for _ in range(4)],
                       0)
    x_train = torch.cat(
        [train_iter.__next__()[0].to(device) for _ in range(4)], 0)
    assert kwargs['saved_model']
    print("Loading...")
    print(kwargs['saved_model'])
    sample_bpds = []
    test_bpds = []
    train_bpds = []
    idxs = range(0, 20000, 1000)
    for idx in tqdm(idxs):
        model = torch.load(
            os.path.join(os.path.dirname(kwargs['saved_model']),
                         f'ckpt_{idx}.pt'))
        with torch.no_grad():
            fake = generate_from_noise(model, 500)
            z, bpd, y_logits, (prior, logdet) = model.forward(
                fake, None, return_details=True)
            sample_bpds.append(bpd)
            grid = make_grid((postprocess(fake.detach().cpu())[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(os.path.join(kwargs['output_dir'],
                                     f'sample_{idx}.png'),
                        bbox_inches='tight',
                        pad_inches=0)
        # NOTE(review): a second `with torch.no_grad():` around the test/train
        # forward passes appears to have been commented out — confirm scope.
        # with torch.no_grad():
        z, bpd, y_logits, (prior, logdet) = model.forward(
            x_test, None, return_details=True)
        test_bpds.append(bpd)
        z, bpd, y_logits, (prior, logdet) = model.forward(
            x_train, None, return_details=True)
        train_bpds.append(bpd)
    plt.figure(figsize=(10, 7))
    plt.clf()

    def collate_bpds(bpds, idxs):
        # Flatten the per-checkpoint BPD tensors into paired (iteration,
        # bpd) arrays suitable for a scatter plot.
        xdata = []
        ydata = []
        for n, idx in enumerate(idxs):
            bpd = bpds[n].cpu().numpy()
            xdata.append(idx * np.ones_like(bpd))
            ydata.append(bpd)
        xdata = np.array(xdata)
        ydata = np.array(ydata)
        return xdata, ydata

    # Offset each series on the x-axis so the clouds don't overlap; clip
    # extreme BPDs so outliers don't dominate the axis.
    xdata, ydata = collate_bpds(sample_bpds, idxs)
    ydata = np.clip(ydata, a_min=0, a_max=1e5)
    plt.scatter(xdata + 200, ydata, c='k', s=10, alpha=.6,
                label='sample (n=500)')
    xdata, ydata = collate_bpds(test_bpds, idxs)
    ydata = np.clip(ydata, a_min=0, a_max=1e5)
    plt.scatter(xdata + 400, ydata, c='r', s=10, alpha=.6,
                label='test data (n=2k)')
    xdata, ydata = collate_bpds(train_bpds, idxs)
    ydata = np.clip(ydata, a_min=0, a_max=1e5)
    plt.scatter(xdata + 600, ydata, c='g', s=10, alpha=.6,
                label='train data (n=2k)')
    plt.legend()
    plt.xlabel('Training Iteration')
    plt.ylabel('Bits Per Dim')
    plt.tight_layout()
    plt.savefig(os.path.join(kwargs['output_dir'], f'bpds.png'),
                bbox_inches='tight',
                pad_inches=0)
def sample(model):
    """Draw 32 unconditional samples (raw model output, no post-processing)
    and return them on CPU. Requires hparams['y_condition'] to be False."""
    with torch.no_grad():
        assert not hparams['y_condition']
        y = None
        images = model(y_onehot=y, temperature=1, reverse=True, batch_size=32)
        # images = postprocess(model(y_onehot=y, temperature=1, reverse=True))
    return images.cpu()


# (Disabled: loop that accumulated 10k samples batch-by-batch,
#  post-processed them and saved to '10k_samples.pt'.)
# Sample once and save a 6-wide grid of the first 30 images.
images = sample(model)
# ipdb.set_trace()
grid = make_grid((postprocess(images)[:30]), nrow=6).permute(1, 2, 0)
plt.figure(figsize=(10, 10))
plt.imshow(grid)
plt.axis('off')
plt.savefig(os.path.join(output_folder, 'sample.png'))
def gan_step(engine, batch):
    """One combined GAN + flow training step (ignite-style handler).

    Performs, in order: a discriminator update on real vs. generated images
    (with gradient penalty), a generator/flow update whose loss mixes an
    adversarial term with weighted prior/logdet likelihood terms, optional
    sample-entropy and Jacobian regularizers, and periodic diagnostics
    (sample grids, checkpoints, recon PAD, FID, eval BPD, Jacobian SVD).

    Args:
        engine: ignite engine; ``engine.iter_ind`` is (lazily) maintained here
            as the global step counter.
        batch: (x, y) pair from the data loader; y is unused (unconditional).

    Returns:
        dict with a single 'total_loss' float (mean NLL), for logging only.
    """
    assert not y_condition
    # Lazily attach a step counter to the engine on first call.
    if 'iter_ind' in dir(engine):
        engine.iter_ind += 1
    else:
        engine.iter_ind = -1
    losses = {}
    model.train()
    discriminator.train()

    x, y = batch
    x = x.to(device)

    def run_noised_disc(discriminator, x):
        # Dequantize inputs before the discriminator sees them, so real and
        # generated images share the same noise distribution.
        x = uniform_binning_correction(x)[0]
        return discriminator(x)

    real_acc = fake_acc = acc = 0
    if weight_gan > 0:
        # ---- Discriminator update ----
        fake = generate_from_noise(model, x.size(0), clamp=clamp)

        D_real_scores = run_noised_disc(discriminator, x.detach())
        D_fake_scores = run_noised_disc(discriminator, fake.detach())

        ones_target = torch.ones((x.size(0), 1), device=x.device)
        zeros_target = torch.zeros((x.size(0), 1), device=x.device)
        # F.sigmoid is deprecated; torch.sigmoid is the supported spelling.
        D_real_accuracy = torch.sum(
            torch.round(torch.sigmoid(D_real_scores)) ==
            ones_target).float() / ones_target.size(0)
        D_fake_accuracy = torch.sum(
            torch.round(torch.sigmoid(D_fake_scores)) ==
            zeros_target).float() / zeros_target.size(0)

        D_real_loss = F.binary_cross_entropy_with_logits(
            D_real_scores, ones_target)
        D_fake_loss = F.binary_cross_entropy_with_logits(
            D_fake_scores, zeros_target)
        D_loss = (D_real_loss + D_fake_loss) / 2
        gp = gradient_penalty(
            x.detach(), fake.detach(),
            lambda _x: run_noised_disc(discriminator, _x))
        D_loss_plus_gp = D_loss + 10 * gp
        D_optimizer.zero_grad()
        D_loss_plus_gp.backward()
        D_optimizer.step()

        # ---- Generator adversarial loss (fresh fake, grads flow to model) ----
        fake = generate_from_noise(model,
                                   x.size(0),
                                   clamp=clamp,
                                   guard_nans=False)
        G_loss = F.binary_cross_entropy_with_logits(
            run_noised_disc(discriminator, fake),
            torch.ones((x.size(0), 1), device=x.device))

        # Trace
        real_acc = D_real_accuracy.item()
        fake_acc = D_fake_accuracy.item()
        acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

    z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                      None,
                                                      return_details=True)
    train_bpd = nll.mean().item()

    loss = 0
    if weight_gan > 0:
        loss = loss + weight_gan * G_loss
    if weight_prior > 0:
        loss = loss + weight_prior * -prior.mean()
    if weight_logdet > 0:
        loss = loss + weight_logdet * -logdet.mean()

    if weight_entropy_reg > 0:
        # NOTE(review): `fake` is only defined when weight_gan > 0 — this
        # branch assumes both weights are enabled together; confirm config.
        _, _, _, (sample_prior, sample_logdet) = model.forward(
            fake, None, return_details=True)
        # notice this is actually "decreasing" sample likelihood.
        loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                            sample_logdet.mean())
    # Jac Reg
    if jac_reg_lambda > 0:
        # Sample: penalize Jacobians of forward (x->z) and inverse (z->x)
        # maps at generated points.
        x_samples = generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach()
        x_samples.requires_grad_()
        z = model.forward(x_samples, None, return_details=True)[0]
        other_zs = torch.cat([
            split._last_z2.view(x.size(0), -1)
            for split in model.flow.splits
        ], -1)
        all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
        sample_forward_jac = compute_jacobian_regularizer(x_samples,
                                                          all_z,
                                                          n_proj=1)
        _, c2, h, w = model.prior_h.shape
        c = c2 // 2
        zshape = (batch_size, c, h, w)
        randz = torch.randn(zshape).to(device)
        randz = torch.autograd.Variable(randz, requires_grad=True)
        images = model(z=randz,
                       y_onehot=None,
                       temperature=1,
                       reverse=True,
                       batch_size=0)
        other_zs = [split._last_z2 for split in model.flow.splits]
        all_z = [randz] + other_zs
        sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
            all_z, images, n_proj=1)

        # Data: same two penalties evaluated at real data points.
        x.requires_grad_()
        z = model.forward(x, None, return_details=True)[0]
        other_zs = torch.cat([
            split._last_z2.view(x.size(0), -1)
            for split in model.flow.splits
        ], -1)
        all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
        data_forward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
        _, c2, h, w = model.prior_h.shape
        c = c2 // 2
        zshape = (batch_size, c, h, w)
        z.requires_grad_()
        images = model(z=z,
                       y_onehot=None,
                       temperature=1,
                       reverse=True,
                       batch_size=0)
        other_zs = [split._last_z2 for split in model.flow.splits]
        all_z = [z] + other_zs
        data_inverse_jac = compute_jacobian_regularizer_manyinputs(
            all_z, images, n_proj=1)

        loss = loss + jac_reg_lambda * (sample_forward_jac +
                                        sample_inverse_jac +
                                        data_forward_jac + data_inverse_jac)

    if not eval_only:
        optimizer.zero_grad()
        loss.backward()
        if not db:
            assert max_grad_clip == max_grad_norm == 0
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        # Replace NaN gradient with 0 (x != x is only true for NaN).
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                g = p.grad.data
                g[g != g] = 0
        optimizer.step()

    # Periodic sample grid for visual monitoring.
    if engine.iter_ind % 100 == 0:
        with torch.no_grad():
            fake = generate_from_noise(model, x.size(0), clamp=clamp)
            z = model.forward(fake, None, return_details=True)[0]
        print("Z max min")
        print(z.max().item(), z.min().item())
        if (fake != fake).float().sum() > 0:
            title = 'NaNs'
        else:
            title = "Good"
        grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                         nrow=6).permute(1, 2, 0)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid)
        plt.axis('off')
        plt.title(title)
        plt.savefig(os.path.join(output_dir,
                                 f'sample_{engine.iter_ind}.png'))

    # Periodic heavy evaluation: checkpoint, recon PAD, FID, eval BPD.
    if engine.iter_ind % eval_every == 0:

        def check_all_zero_except_leading(x):
            # True for 1, 20, 300, 4000, ... — checkpoint on "round" steps.
            return x % 10**np.floor(np.log10(x)) == 0

        if engine.iter_ind == 0 or check_all_zero_except_leading(
                engine.iter_ind):
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

        model.eval()
        with torch.no_grad():
            # Plot recon
            fpath = os.path.join(output_dir, '_recon',
                                 f'recon_{engine.iter_ind}.png')
            sample_pad = run_recon_evolution(
                model,
                generate_from_noise(model, args.batch_size,
                                    clamp=clamp).detach(), fpath)
            print(
                f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

            pad = run_recon_evolution(model, x_for_recon, fpath)
            print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
            pad = pad.item()
            sample_pad = sample_pad.item()

            # Inception score / FID on N_inception generated images.
            sample = torch.cat([
                generate_from_noise(model, args.batch_size, clamp=clamp)
                for _ in range(N_inception // args.batch_size + 1)
            ], 0)[:N_inception]
            sample = sample + .5

            if (sample != sample).float().sum() > 0:
                print("Sample NaNs")
                # A bare `raise` here has no active exception and would
                # itself error; raise explicitly instead.
                raise RuntimeError("Sample NaNs")
            else:
                fid = run_fid(x_real_inception.clamp_(0, 1),
                              sample.clamp_(0, 1))
                print(f'fid: {fid}, global_iter: {engine.iter_ind}')

            # Eval BPD
            eval_bpd = np.mean([
                model.forward(x.to(device), None,
                              return_details=True)[1].mean().item()
                for x, _ in test_loader
            ])

        stats_dict = {
            'global_iteration': engine.iter_ind,
            'fid': fid,
            'train_bpd': train_bpd,
            'pad': pad,
            'eval_bpd': eval_bpd,
            'sample_pad': sample_pad,
            'batch_real_acc': real_acc,
            'batch_fake_acc': fake_acc,
            'batch_acc': acc
        }
        iteration_logger.writerow(stats_dict)
        plot_csv(iteration_logger.filename)
        model.train()

    # Periodic Jacobian SVD diagnostics.  The original condition was
    # `engine.iter_ind + 2 % svd_every == 0`, which parses as
    # `iter_ind + (2 % svd_every) == 0` and never fires for svd_every > 2;
    # the intended modulo over the shifted counter is restored here.
    if (engine.iter_ind + 2) % svd_every == 0:
        model.eval()
        svd_dict = {}
        ret = utils.computeSVDjacobian(x_for_recon, model)
        D_for, D_inv = ret['D_for'], ret['D_inv']
        cn = float(D_for.max() / D_for.min())
        cn_inv = float(D_inv.max() / D_inv.min())
        svd_dict['global_iteration'] = engine.iter_ind
        svd_dict['condition_num'] = cn
        svd_dict['max_sv'] = float(D_for.max())
        svd_dict['min_sv'] = float(D_for.min())
        svd_dict['inverse_condition_num'] = cn_inv
        svd_dict['inverse_max_sv'] = float(D_inv.max())
        svd_dict['inverse_min_sv'] = float(D_inv.min())
        svd_logger.writerow(svd_dict)
        model.train()
        if eval_only:
            sys.exit()

    # Dummy
    losses['total_loss'] = torch.mean(nll).item()
    return losses
def _save_grid(samples, fpath, nrow=10):
    """Arrange a batch of images into an *nrow*-wide grid and save to *fpath*."""
    grid = make_grid(postprocess(samples.detach().cpu()),
                     nrow=nrow).permute(1, 2, 0)
    plt.figure(figsize=(5, 5))
    plt.imshow(grid)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(fpath, bbox_inches='tight', pad_inches=0)


def main(kwargs):
    """Visualize latent-space behavior of a series of saved checkpoints.

    For each checkpoint ``ckpt_{idx}.pt`` (idx = 0, 5000, 10000, 15000) next
    to ``kwargs['saved_model']``, take one training image (repeated 10x),
    encode it, and save five image grids: z shrunk towards 0 linearly and
    logarithmically — each with the stored split latents reused
    (``use_last_split=True``) or re-sampled — plus repeated decodes of the
    same z to show reconstruction stochasticity.

    Args:
        kwargs: dict of CLI options; reads 'seed', dataset/loader settings,
            'pgd_*' options, 'saved_model' and 'output_dir'.

    Raises:
        ValueError: if 'pgd_f_project' is neither 'l2' nor 'none'.
    """
    check_manual_seed(kwargs['seed'])

    ds = check_dataset(kwargs['dataset'], kwargs['dataroot'], False,
                       kwargs['download'])
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=kwargs['batch_size'],
                                   shuffle=True,
                                   num_workers=kwargs['n_workers'],
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=kwargs['eval_batch_size'],
                                  shuffle=False,
                                  num_workers=kwargs['n_workers'],
                                  drop_last=False)
    test_iter = cycle(test_loader)
    train_iter = cycle(train_loader)

    # One fixed batch of test and train data.
    x_test = torch.cat([next(test_iter)[0].to(device) for _ in range(1)], 0)
    x_train = torch.cat([next(train_iter)[0].to(device) for _ in range(1)], 0)

    if kwargs['pgd_f_project'] == 'l2':
        f_project = lambda x, x_0: l2_project(x, x_0, t=kwargs['pgd_l2_t'])
    elif kwargs['pgd_f_project'] == 'none':
        f_project = None
    else:
        # A bare `raise` here would fail with "no active exception"; be explicit.
        raise ValueError(
            f"unknown pgd_f_project: {kwargs['pgd_f_project']!r}")
    # PGD runner configured from kwargs (kept available for interactive use).
    run_pgd = lambda x, f_loss: pgd(x,
                                    f_loss=f_loss,
                                    step_size=kwargs['pgd_step_size'],
                                    n_steps=kwargs['pgd_n_steps'],
                                    f_project=f_project)

    assert kwargs['saved_model']
    print("Loading...")
    print(kwargs['saved_model'])

    n_grid = 9  # number of interpolation steps per grid (plus the reference row)
    idxs = range(0, 20000, 5000)
    for idx in tqdm(idxs):
        model = torch.load(
            os.path.join(os.path.dirname(kwargs['saved_model']),
                         f'ckpt_{idx}.pt'))
        model.eval()

        # Ten copies of a single training image.
        x = x_train[:1].repeat(10, 1, 1, 1).clone()

        # 1) Linear shrink of z towards 0, reusing stored split latents.
        z = model.forward(x, None, return_details=True, correction=False)[0]
        samples = [x]
        for n in range(n_grid):
            curr_z = z * (1 - float(n) / n_grid)
            samples.append(
                model(y_onehot=None,
                      temperature=1,
                      z=curr_z,
                      reverse=True,
                      use_last_split=True))
        _save_grid(torch.cat(samples, 0),
                   os.path.join(kwargs['output_dir'],
                                f'linear_last_{idx}.png'))

        # 2) Linear shrink, re-sampling split latents each decode.
        z = model.forward(x, None, return_details=True, correction=False)[0]
        samples = [x]
        for n in range(n_grid):
            curr_z = z * (1 - float(n) / n_grid)
            samples.append(
                model(y_onehot=None,
                      temperature=1,
                      z=curr_z,
                      reverse=True,
                      use_last_split=False))
        _save_grid(torch.cat(samples, 0),
                   os.path.join(kwargs['output_dir'],
                                f'linear_no_last_{idx}.png'))

        # 3) Log-scale shrink (down to 1/||z|| of the first image), with
        #    stored split latents; the first row is the plain reconstruction.
        z = model.forward(x, None, return_details=True, correction=False)[0]
        s = model(y_onehot=None,
                  temperature=1,
                  z=z,
                  reverse=True,
                  use_last_split=True)
        samples = [s]
        scales = np.logspace(
            0,
            np.log10(1 / z.view(10, -1).pow(2).sum(-1).pow(.5)[0].item()),
            n_grid)
        for n in range(n_grid):
            curr_z = z * scales[n]
            s = model(y_onehot=None,
                      temperature=1,
                      z=curr_z,
                      reverse=True,
                      use_last_split=True)
            samples.append(s.clone())
        _save_grid(torch.cat(samples, 0),
                   os.path.join(kwargs['output_dir'], f'log_last_{idx}.png'))

        # 4) Log-scale shrink, re-sampling split latents.
        z = model.forward(x, None, return_details=True, correction=False)[0]
        samples = [x]
        for n in range(n_grid):
            curr_z = z * np.logspace(
                0,
                np.log10(1 / z.view(10, -1).pow(2).sum(-1).pow(.5)[0].item()),
                n_grid)[n]
            samples.append(
                model(y_onehot=None,
                      temperature=1,
                      z=curr_z,
                      reverse=True,
                      use_last_split=False))
        _save_grid(torch.cat(samples, 0),
                   os.path.join(kwargs['output_dir'],
                                f'log_no_last_{idx}.png'))

        # 5) Random recons: decode the same z repeatedly with fresh split
        #    latents to visualize reconstruction stochasticity.
        z = model.forward(x, None, return_details=True, correction=False)[0]
        samples = [x]
        for n in range(n_grid):
            samples.append(
                model(y_onehot=None,
                      temperature=1,
                      z=z,
                      reverse=True,
                      use_last_split=False))
        _save_grid(torch.cat(samples, 0),
                   os.path.join(kwargs['output_dir'],
                                f'random_recons_{idx}.png'))