def reverse_flow(self, z, y_onehot, temperature):
    """Run the flow in reverse to map a latent ``z`` back to data space.

    If ``z`` is None, a latent is sampled at ``temperature``:
      * with ``self.y_condition`` set, one part is drawn from
        ``self.prior(None, y_onehot)`` and another from the Gaussian
        ``(self.mean_normal, self.logs_normal)``, and the two are concatenated;
      * otherwise the whole latent is drawn from ``self.prior(None, y_onehot)``.

    Everything runs under ``torch.no_grad()``.

    Returns:
        The output of ``self.flow(z, temperature=temperature, reverse=True)``.
    """
    with torch.no_grad():
        if z is None:
            mean, logs = self.prior(z, y_onehot)
            if self.y_condition:
                z_y = gaussian_sample(mean, logs, temperature)
                z_n = gaussian_sample(self.mean_normal, self.logs_normal,
                                      temperature)
                # torch.cat expects a *sequence* of tensors. The previous
                # `torch.cat(z_y, z_n)` passed z_n as `dim` and always raised.
                # NOTE(review): dim=1 assumes channel-first latents
                # (B, C, H, W) split along channels -- confirm against
                # self.prior / self.flow.
                z = torch.cat((z_y, z_n), dim=1)
            else:
                z = gaussian_sample(mean, logs, temperature)
        x = self.flow(z, temperature=temperature, reverse=True)
    return x
def reverse_flow(self, z, zn, y_onehot, temperature, no_grad=True):
    """Run the flow in reverse on latent ``z`` (plus ``zn``).

    If ``z`` is None, it is sampled at ``temperature`` from
    ``self.prior(None, y_onehot)``. ``zn`` is forwarded unchanged to
    ``self.flow``.

    Args:
        no_grad: when True (default), run under ``torch.no_grad()``.
            When False, run in the caller's current autograd mode.

    Returns:
        The output of ``self.flow(z, zn, temperature=temperature, reverse=True)``.
    """
    import contextlib

    # One body for both modes. nullcontext leaves the caller's grad mode
    # untouched, exactly like the former duplicated "else" branch.
    grad_ctx = torch.no_grad() if no_grad else contextlib.nullcontext()
    with grad_ctx:
        if z is None:
            mean, logs = self.prior(z, y_onehot)
            z = gaussian_sample(mean, logs, temperature)
        x = self.flow(z, zn, temperature=temperature, reverse=True)
    return x
def reverse_flow(self, z, y_onehot, temperature, batch_size):
    """Decode a latent through the flow, running it in reverse without autograd.

    A missing ``z`` is replaced by a draw, at ``temperature``, from the
    Gaussian whose parameters ``self.prior(None, y_onehot, batch_size)``
    returns. ``y_onehot`` is also handed to ``self.flow`` as its second
    positional argument.
    """
    with torch.no_grad():
        latent = z
        if latent is None:
            prior_mean, prior_logs = self.prior(None, y_onehot, batch_size)
            latent = gaussian_sample(prior_mean, prior_logs, temperature)
        decoded = self.flow(latent, y_onehot,
                            temperature=temperature, reverse=True)
    return decoded
def reverse_flow(self, z, y_onehot, temperature, use_last_split=False,
                 batch_size=0):
    """Decode ``z`` by running ``self.flow`` in reverse.

    When ``z`` is None, a latent is drawn at ``temperature`` from the
    Gaussian returned by ``self.prior(None, y_onehot, batch_size=...)``, and
    a clone of it is stored in ``self._last_z``.

    When ``use_last_split`` is True, every layer in ``self.flow.splits``
    gets ``use_last = True`` before decoding.
    NOTE(review): nothing here resets that flag to False -- confirm the
    split layers (or the caller) clear it.

    This method does not itself disable autograd.
    """
    if z is None:
        prior_mean, prior_logs = self.prior(z, y_onehot, batch_size=batch_size)
        z = gaussian_sample(prior_mean, prior_logs, temperature)
        self._last_z = z.clone()

    if use_last_split:
        for split_layer in self.flow.splits:
            split_layer.use_last = True

    return self.flow(z, temperature=temperature, reverse=True)
def fischer_approximation_from_model(model, T=1000, temperature=1,
                                     generated=True, sampling_dataset=None):
    """Return the reciprocal of the mean squared NLL gradient of ``model``.

    For each sample ``x`` the NLL (second element of
    ``model(x.unsqueeze(0))``) is back-propagated. The squared gradients of
    every parameter that received one are flattened, concatenated and
    averaged over the valid samples. Samples whose squared gradient contains
    inf or NaN are skipped and do not count towards ``T``.

    Args:
        model: module returning a 3-tuple with the NLL in second position.
            It also needs ``prior`` and ``flow`` when ``generated`` is True.
        T: number of valid samples to average.
        temperature: sampling temperature, used only when ``generated``.
        generated: if True, each sample is decoded from a fresh prior draw.
            Otherwise ``sampling_dataset[i][0]`` is read for i = 0, 1, ...
        sampling_dataset: indexable dataset, used when ``generated`` is False.
            If it runs out before ``T`` valid samples, the average is taken
            over the samples collected so far.

    Returns:
        Flat tensor ``1 / (mean_squared_grad + 1e-8)``.

    Raises:
        ValueError: if no valid sample could be collected (e.g. ``T <= 0``).
    """
    device = next(model.parameters()).device
    mean_sq_grad = None
    n = 0
    index = 0
    while n < T:
        if generated:
            # NOTE(review): if the model always yields non-finite gradients,
            # this mode keeps resampling indefinitely.
            with torch.no_grad():
                mean, logs = model.prior(None, batch_size=1)
                z = gaussian_sample(mean, logs, temperature=temperature)
                x = model.flow(z, temperature=temperature, reverse=True)[0]
        else:
            if index >= len(sampling_dataset):
                break  # dataset exhausted: average what we have
            x = sampling_dataset[index][0].to(device)
            index += 1

        model.zero_grad()
        _, nll, _ = model(x.unsqueeze(0))
        nll.backward()

        # The gradient's sign is irrelevant once squared.
        sq_grad = torch.cat([param.grad.view(-1)
                             for param in model.parameters()
                             if param.grad is not None]) ** 2
        if not torch.isfinite(sq_grad).all():
            continue

        n += 1
        if mean_sq_grad is None:
            mean_sq_grad = sq_grad
        else:
            # Incremental mean over the n valid samples seen so far. The old
            # `(1 / n + 1) * (n * g + total)` did not compute an average.
            mean_sq_grad += (sq_grad - mean_sq_grad) / n

    if mean_sq_grad is None:
        raise ValueError(
            "No sample with a finite gradient was available for the "
            "Fisher approximation")
    return 1. / (mean_sq_grad + 1e-8)
def fischer_approximation(model, T=1000, temperature=1, generated=True,
                          dataset=False):
    """Return the reciprocal of the mean squared NLL gradient of ``model``.

    ``T`` inputs are taken either from the model itself (one batch of ``T``
    prior samples decoded with ``model.flow``) or as ``dataset[k][0]`` for
    ``k < T``. For each input, the NLL (second element of the model output
    on ``x.unsqueeze(0)``) is back-propagated on a deep copy of ``model``,
    so ``model``'s own ``.grad`` fields are never touched. The flattened
    squared gradients are then summed. Inputs whose squared gradient
    contains inf or NaN are skipped.

    Returns:
        ``n_valid / (sum_squared_grad + 1e-8)``, i.e. the reciprocal of the
        mean squared gradient over the valid inputs, as a flat tensor.

    Raises:
        ValueError: if no input produced a finite gradient.
    """
    print("Fischer Approximation")
    if generated:
        with torch.no_grad():
            mean, logs = model.prior(None, batch_size=T)
            z = gaussian_sample(mean, logs, temperature=temperature)
            list_img = model.flow(z, temperature=temperature, reverse=True)
    else:
        list_img = [dataset[k][0] for k in range(T)]

    # A single working copy suffices because gradients are cleared before
    # every sample. Copying once avoids T full deep copies of the model.
    # NOTE(review): per-sample copies also discarded any state a forward
    # pass mutates (e.g. lazy / data-dependent init in training mode).
    # This assumes the model is already initialised or in eval mode.
    model_copy = copy.deepcopy(model)
    device = next(model_copy.parameters()).device

    total_grad = None
    n_valid = 0
    for x in tqdm.tqdm(list_img):
        model_copy.zero_grad()
        # .to() is a no-op for inputs already on the model's device.
        _, nll, _ = model_copy(x.to(device).unsqueeze(0))
        nll.backward()
        # The gradient's sign is irrelevant once squared.
        current_grad = torch.cat([param.grad.view(-1)
                                  for param in model_copy.parameters()
                                  if param.grad is not None]) ** 2
        if not torch.isfinite(current_grad).all():
            continue
        n_valid += 1
        if total_grad is None:
            total_grad = current_grad
        else:
            total_grad += current_grad

    if total_grad is None:
        raise ValueError(
            "No input with a finite gradient was available for the "
            "Fisher approximation")
    return float(n_valid) / (total_grad + 1e-8)
def _sanitize_generated_images(images_raw):
    """Set NaN/inf pixels of a generated batch to 0.5, then clamp to [-0.5, 0.5].

    The NaN/inf replacement mutates ``images_raw`` in place. The clamped
    result is returned as a new tensor.
    """
    images_raw[torch.isnan(images_raw)] = 0.5
    images_raw[torch.isinf(images_raw)] = 0.5
    return torch.clamp(images_raw, -0.5, 0.5)


def _estimate_step_size(model, args, output_folder, temperature):
    """Estimate the latent interpolation step from generated images.

    Decodes one prior batch and takes modality ``args.step_modality``.
    It binarizes that slice stack (threshold 0 for binary data, Otsu
    otherwise) and averages its x/y two-point correlations. It then fits
    ``straight_line_at_origin(cov_avg[0])`` to the first N=5 lags and
    returns ``|cov_avg[0] / S20|`` truncated to an int, floored at 1.
    """
    fig_dir = os.path.join(output_folder, 'stepnum_results')
    os.makedirs(fig_dir, exist_ok=True)
    print('No step size entered')

    # Create sample of images to estimate chord length
    with torch.no_grad():
        mean, logs = model.prior(None, None)
        z = gaussian_sample(mean, logs, temperature)
        images_raw = model(z=z, temperature=temperature, reverse=True)
    images_raw = _sanitize_generated_images(images_raw)
    images_out = np.transpose(
        np.squeeze(images_raw[:, args.step_modality, :, :].cpu().numpy()),
        (1, 0, 2))

    # Threshold images and compute covariances
    thresh = 0 if args.binary_data else threshold_otsu(images_out)
    images_bin = np.greater(images_out, thresh)
    x_cov = two_point_correlation(images_bin, 0)
    y_cov = two_point_correlation(images_bin, 1)

    # Compute chord length
    cov_avg = np.mean(
        np.mean(np.concatenate((x_cov, y_cov), axis=2), axis=0), axis=0)
    N = 5
    S20, _ = curve_fit(straight_line_at_origin(cov_avg[0]), range(0, N),
                       cov_avg[0:N])
    l_pore = np.abs(cov_avg[0] / S20)
    # curve_fit returns an array: .item() avoids the deprecated int(ndarray).
    # A step of 0 would later cause a division by zero, so floor it at 1.
    steps = max(1, int(np.asarray(l_pore).item()))
    print('Calculated step size: {}'.format(steps))
    return steps


def _sample_interpolated_volume(model, hparams, steps, temperature, device):
    """Decode a stack of slices by blending consecutive prior latents.

    Each consecutive pair (z[i], z[i+1]) from one prior batch contributes
    ``steps`` linearly blended latents. The result is truncated to
    ``hparams['patch_size']`` latents, decoded in one call, and sanitized.
    NOTE(review): assumes the prior batch has at least
    ceil(patch_size / steps) + 1 samples.
    """
    with torch.no_grad():
        mean, logs = model.prior(None, None)
        alpha = 1 - torch.reshape(torch.linspace(0, 1, steps=steps),
                                  (-1, 1, 1, 1))
        alpha = alpha.to(device)
        num_imgs = int(np.ceil(hparams['patch_size'] / steps) + 1)
        z = gaussian_sample(mean, logs, temperature)[:num_imgs, ...]
        z = torch.cat([
            alpha * z[i, ...] + (1 - alpha) * z[i + 1, ...]
            for i in range(num_imgs - 1)
        ])
        z = z[:hparams['patch_size'], ...]
        images_raw = model(z=z, temperature=temperature, reverse=True)
    return _sanitize_generated_images(images_raw)


def _median_filter_volume(images_raw, args, device):
    """Median-filter every modality of ``images_raw`` in place.

    Uses a ``ball(args.med_filt)`` footprint, or ``ball(1)`` for binary data.
    Binary data is also thresholded at 0, eroded with ``ball(1)`` and mapped
    back to {-0.5, 0.5}.
    """
    for m in range(args.n_modalities):
        SE = ball(1) if args.binary_data else ball(args.med_filt)
        images_np = np.squeeze(images_raw[:, m, :, :].cpu().numpy())
        images_filt = median_filter(images_np, footprint=SE)
        # Erode binary images
        if args.binary_data:
            images_filt = np.greater(images_filt, 0)
            # NOTE(review): recent scikit-image releases rename `selem` to
            # `footprint` -- check against the installed version.
            images_filt = 1.0 * binary_erosion(images_filt, selem=ball(1)) - 0.5
        images_raw[:, m, :, :] = torch.tensor(images_filt, device=device)


def _binarize_otsu(images1, images2, images3):
    """Binarize three views of one volume in place at ``images1``'s Otsu level.

    Pixels strictly above the threshold become 255 and all others 0.
    Previously, pixels exactly equal to the threshold were left untouched.
    """
    thresh = threshold_otsu(images1.numpy())
    for images in (images1, images2, images3):
        foreground = images > thresh
        images[~foreground] = 0
        images[foreground] = 255


def main(args):
    """Load a trained Glow model and write sampled image stacks as videos.

    Reads ``results/<args.name>/hparams.json`` and the checkpoint
    ``results/<args.name>/checkpoints/<args.model>``. It then uses
    ``args.steps`` or estimates a step size. For each of ``args.iter``
    volumes it decodes a stack of slices, optionally median-filters and/or
    Otsu-binarizes it, and writes 'xy', 'xz' and 'yz' videos per modality
    under ``results/<args.name>/image_stacks``.
    """
    output_folder = os.path.join('results', args.name)
    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    image_shape = (hparams['patch_size'], hparams['patch_size'],
                   args.n_modalities)
    num_classes = 1

    print('Loading model...')
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])
    model_chkpt = torch.load(
        os.path.join(output_folder, 'checkpoints', args.model))
    model.load_state_dict(model_chkpt['model'])
    model.set_actnorm_init()
    model = model.to(device)
    model.eval()
    temperature = args.temperature

    if args.steps is None:
        # automatically calculate step size if no step size
        steps = _estimate_step_size(model, args, output_folder, temperature)
    else:
        print('Using user-entered step size {}...'.format(args.steps))
        steps = args.steps

    # Build desired number of volumes
    for iter_vol in range(args.iter):
        if args.iter == 1:
            stack_dir = os.path.join(output_folder, 'image_stacks',
                                     args.save_name)
            print('Sampling images, saving to {}...'.format(args.save_name))
        else:
            stack_dir = os.path.join(
                output_folder, 'image_stacks',
                args.save_name + '_' + str(iter_vol).zfill(3))
            print('Sampling images, saving to {}_'.format(args.save_name) +
                  str(iter_vol).zfill(3) + '...')
        os.makedirs(stack_dir, exist_ok=True)

        images_raw = _sample_interpolated_volume(model, hparams, steps,
                                                 temperature, device)

        # apply median filter to output
        if args.med_filt is not None or args.binary_data:
            _median_filter_volume(images_raw, args, device)

        images1 = postprocess(images_raw).cpu()
        images2 = postprocess(torch.transpose(images_raw, 0, 2)).cpu()
        images3 = postprocess(torch.transpose(images_raw, 0, 3)).cpu()

        # apply Otsu thresholding to output
        if args.save_binary and not args.binary_data:
            _binarize_otsu(images1, images2, images3)

        # save video for each modality
        for m in range(args.n_modalities):
            if args.n_modalities > 1:
                save_dir = os.path.join(stack_dir, 'modality{}'.format(m))
            else:
                save_dir = stack_dir
            os.makedirs(save_dir, exist_ok=True)
            write_video(images1[:, m, :, :], 'xy', hparams, save_dir)
            write_video(images2[:, m, :, :], 'xz', hparams, save_dir)
            write_video(images3[:, m, :, :], 'yz', hparams, save_dir)
    print('Finished!')
def fischer_approximation_from_model(model, T=1000, temperature=1,
                                     type_fischer="generated",
                                     sampling_dataset=None):
    """Estimate the diagonal empirical Fisher information of ``model``.

    Averages the squared NLL gradients over at most ``T`` valid samples.
    The NLL is the second element of ``model(x.unsqueeze(0))``, and the
    gradients of every parameter that received one are flattened and
    concatenated. A sample is skipped if no parameter received a gradient
    or if any gradient entry is inf/NaN.

    Args:
        model: module returning a 3-tuple with the NLL in second position.
            It also needs ``prior``/``flow`` unless ``type_fischer`` is
            "sampled".
        T: maximum number of valid samples to average.
        temperature: sampling temperature for generated samples.
        type_fischer: the sampling mode.
            * "generated": each sample is decoded from a fresh prior draw.
            * "sampled": ``sampling_dataset[i][0]`` is used in shuffled order.
            * any other value: a tensor of ones shaped like the Fisher
              diagonal, obtained from the first valid generated sample.
        sampling_dataset: indexable dataset. Its length bounds the number of
            draws in every mode (failed draws included), so it is required.

    Returns:
        Flat tensor ``mean_squared_grad + 1e-8``, or ones for the fallback mode.

    Raises:
        ValueError: if ``sampling_dataset`` is None.
        RuntimeError: if no draw produced a usable gradient.
    """
    if sampling_dataset is None:
        raise ValueError(
            "sampling_dataset is required: its length bounds the number of draws")
    device = next(model.parameters()).device
    ones_mode = type_fischer not in ("generated", "sampled")

    fischer_matrix = None
    n = 0
    compteur_empty = 0
    compteur_inf = 0
    index = 0
    indexes = np.arange(0, len(sampling_dataset), step=1)
    random.shuffle(indexes)
    print(f"Fischer with type {type_fischer}")
    while n < T and index < len(sampling_dataset):
        if index % 100 == 0:
            print(f"Index {index} on {len(sampling_dataset)}, n {n} on {T}")
        if type_fischer == "sampled":
            x = sampling_dataset[indexes[index]][0].to(device)
        else:
            with torch.no_grad():
                mean, logs = model.prior(None, batch_size=1)
                z = gaussian_sample(mean, logs, temperature=temperature)
                x = model.flow(z, temperature=temperature, reverse=True)[0]
        # Every draw consumes an index. Previously a failed sample hit
        # `continue` before `index += 1`, re-evaluating the same element
        # forever in "sampled" mode.
        index += 1

        model.zero_grad()
        _, nll, _ = model(x.unsqueeze(0))
        nll.backward()

        grads = [param.grad.view(-1) for param in model.parameters()
                 if param.grad is not None]
        if len(grads) == 0:
            print("No grad calculation available")
            compteur_empty += 1
            continue
        # Reject the whole sample on any non-finite entry. Dropping only the
        # offending parameter would shorten the vector and break the
        # running mean with a shape mismatch.
        current_grad = torch.cat(grads) ** 2
        if not torch.isfinite(current_grad).all():
            print("Found inf here in current grad")
            compteur_inf += 1
            continue

        if ones_mode:
            # Only the shape is needed; return it even when T == 1.
            return torch.ones_like(current_grad)

        if fischer_matrix is None:
            fischer_matrix = current_grad
        else:
            fischer_matrix = (n / (n + 1)) * fischer_matrix + (
                1 / (n + 1)) * current_grad
        n += 1

    print(f"Number of empty is {compteur_empty}")
    print(f"Number of inf is {compteur_inf}")
    if fischer_matrix is None:
        raise RuntimeError(
            "No draw produced a usable gradient for the Fisher approximation")
    return fischer_matrix + 1e-8