Esempio n. 1
0
 def reverse_flow(self, z, y_onehot, temperature):
     """Decode a latent ``z`` back to data space with the inverse flow.

     If ``z`` is None it is sampled at ``temperature``.
     - Without ``self.y_condition``, it is drawn from ``self.prior`` alone.
     - With ``self.y_condition``, one part is drawn from ``self.prior`` and
       the other from the ``self.mean_normal`` / ``self.logs_normal``
       Gaussian, and the two parts are concatenated.

     Everything runs under ``torch.no_grad()``.

     Returns:
         The output of ``self.flow(z, temperature=temperature, reverse=True)``.
     """
     with torch.no_grad():
         if z is None:
             mean, logs = self.prior(z, y_onehot)
             z = gaussian_sample(mean, logs, temperature)
             if self.y_condition:
                 z_n = gaussian_sample(self.mean_normal, self.logs_normal,
                                       temperature)
                 # torch.cat expects a *sequence* of tensors plus a ``dim``.
                 # NOTE(review): assumes the two parts are stacked along the
                 # channel axis (dim=1) of (batch, channels, ...) latents --
                 # confirm against the prior/flow layout.
                 z = torch.cat((z, z_n), dim=1)
         # Decode outside the ``if`` so a caller-supplied ``z`` is decoded
         # too and ``x`` is always bound before the return.
         x = self.flow(z, temperature=temperature, reverse=True)
     return x
Esempio n. 2
0
    def reverse_flow(self, z, zn, y_onehot, temperature, no_grad=True):
        """Map a latent back to data space through the inverse flow.

        If ``z`` is None it is drawn from ``self.prior`` (given
        ``y_onehot``) at ``temperature``. ``zn`` is forwarded unchanged to
        ``self.flow``.

        With ``no_grad`` (the default) the computation runs under
        ``torch.no_grad()``. Otherwise no autograd context is applied here,
        and the caller's autograd mode governs.
        """
        def _decode():
            latent = z
            if latent is None:
                prior_mean, prior_logs = self.prior(None, y_onehot)
                latent = gaussian_sample(prior_mean, prior_logs, temperature)
            return self.flow(latent, zn, temperature=temperature, reverse=True)

        if not no_grad:
            return _decode()
        with torch.no_grad():
            return _decode()
 def reverse_flow(self, z, y_onehot, temperature, batch_size):
     """Decode ``z`` through the inverse flow without tracking gradients.

     When ``z`` is None a latent is drawn from ``self.prior``, which receives
     ``y_onehot`` and ``batch_size``. ``y_onehot`` is also passed on to
     ``self.flow``.
     """
     with torch.no_grad():
         latent = z
         if latent is None:
             prior_mean, prior_logs = self.prior(None, y_onehot, batch_size)
             latent = gaussian_sample(prior_mean, prior_logs, temperature)
         return self.flow(latent,
                          y_onehot,
                          temperature=temperature,
                          reverse=True)
Esempio n. 4
0
 def reverse_flow(self,
                  z,
                  y_onehot,
                  temperature,
                  use_last_split=False,
                  batch_size=0):
     """Decode ``z`` through the inverse flow.

     No ``torch.no_grad()`` context is applied here. When ``z`` is None a
     latent is drawn from ``self.prior`` (given ``y_onehot`` and
     ``batch_size``), and a clone of it is stored in ``self._last_z``. When
     ``use_last_split`` is True, every layer in ``self.flow.splits`` gets
     ``use_last = True`` before decoding.

     NOTE(review): the ``use_last`` flags are only ever set here, never
     reset. They stay True for later calls with ``use_last_split=False``
     unless something else clears them -- confirm this is intended.
     """
     if z is None:
         prior_mean, prior_logs = self.prior(None, y_onehot,
                                             batch_size=batch_size)
         z = gaussian_sample(prior_mean, prior_logs, temperature)
         self._last_z = z.clone()

     if use_last_split:
         for split_layer in self.flow.splits:
             split_layer.use_last = True

     return self.flow(z, temperature=temperature, reverse=True)
Esempio n. 5
0
def fischer_approximation_from_model(model,
                                     T=1000,
                                     temperature=1,
                                     generated=True,
                                     sampling_dataset=None):
    """Return the inverse of a diagonal empirical Fisher estimate.

    For each sample, the gradient of the NLL w.r.t. every parameter that
    receives a gradient is squared element-wise. These squared gradients
    are averaged over ``T`` valid samples. A sample whose squared gradient
    contains inf or NaN is skipped and does not count towards ``T``.

    Args:
        model: Model whose ``forward(x)`` returns a 3-tuple with the NLL in
            the middle position. When ``generated`` is True it must also
            expose ``prior(None, batch_size=...)`` and
            ``flow(z, temperature=..., reverse=True)``.
        T: Number of valid samples to average over.
        temperature: Sampling temperature for generated samples.
        generated: If True, samples are drawn from the model itself.
            Otherwise ``sampling_dataset[i][0]`` is read sequentially.
            NOTE: in generated mode an endless stream of invalid samples
            keeps the loop running.
        sampling_dataset: Indexable dataset used when ``generated`` is
            False. If it is exhausted before ``T`` valid samples are found,
            the average over the samples collected so far is used.

    Returns:
        1-D tensor ``1 / (F + 1e-8)``. ``F`` is the mean squared gradient,
        with parameters concatenated in ``model.parameters()`` order.

    Raises:
        ValueError: if no valid sample could be collected.
    """
    device = next(model.parameters()).device
    total_grad = None  # running mean of squared gradients
    n = 0  # number of valid samples accumulated so far
    index = 0  # next dataset position (dataset mode only)

    while n < T and (generated or index < len(sampling_dataset)):
        if generated:
            with torch.no_grad():
                mean, logs = model.prior(None, batch_size=1)
                z = gaussian_sample(mean, logs, temperature=temperature)
                x = model.flow(z, temperature=temperature, reverse=True)[0]
        else:
            x = sampling_dataset[index][0].to(device)
            index += 1

        model.zero_grad()
        _, nll, _ = model(x.unsqueeze(0))
        nll.backward()
        # The sign of the gradient is irrelevant once squared.
        current_grad = torch.cat([
            param.grad.reshape(-1) for param in model.parameters()
            if param.grad is not None
        ])**2
        # A single inf/NaN entry would poison the whole average.
        if not torch.isfinite(current_grad).all():
            continue
        n += 1
        if total_grad is None:
            total_grad = current_grad
        else:
            # Incremental mean: F_n = F_{n-1} + (g_n - F_{n-1}) / n.
            total_grad = total_grad + (current_grad - total_grad) / n

    if total_grad is None:
        raise ValueError("No valid gradient sample could be collected")
    return 1. / (total_grad + 1e-8)
Esempio n. 6
0
def fischer_approximation(model,
                          T=1000,
                          temperature=1,
                          generated=True,
                          dataset=False):
    """Return the inverse of a diagonal empirical Fisher estimate.

    The NLL gradient of every parameter that receives a gradient is squared
    and summed over the samples. The result is divided by the number of
    valid samples, i.e. it is ``1 / (mean squared gradient)``. Samples whose
    squared gradient contains inf or NaN are skipped.

    Args:
        model: Model whose ``forward(x)`` returns a 3-tuple with the NLL in
            the middle position. When ``generated`` is True it must also
            expose ``prior(None, batch_size=...)`` and
            ``flow(z, temperature=..., reverse=True)``.
        T: Number of samples to draw (all at once from the model, or the
            first ``T`` items of ``dataset``).
        temperature: Sampling temperature for generated samples.
        generated: Draw samples from the model instead of ``dataset``.
        dataset: Indexable dataset; ``dataset[k][0]`` is used as input.

    Returns:
        1-D tensor ``n_valid / (sum of squared gradients + 1e-8)``.

    Raises:
        ValueError: if every sample was skipped.
    """
    print("Fischer Approximation")
    total_grad = None  # sum of squared gradients over accepted samples
    n_valid = 0
    if generated:
        with torch.no_grad():
            mean, logs = model.prior(None, batch_size=T)
            z = gaussian_sample(mean, logs, temperature=temperature)
            list_img = model.flow(z, temperature=temperature, reverse=True)
    else:
        list_img = [dataset[k][0] for k in range(T)]
    for x in tqdm.tqdm(list_img):
        # A fresh copy per sample leaves ``model``'s own gradients untouched
        # and isolates any state the forward pass may mutate.
        # NOTE(review): deep-copying the full model per sample is expensive;
        # hoist it out of the loop if the forward pass is known to be
        # stateless.
        model_copy = copy.deepcopy(model)
        model_copy.zero_grad()
        _, nll, _ = model_copy(x.unsqueeze(0))
        nll.backward()
        # The sign of the gradient is irrelevant once squared.
        current_grad = torch.cat([
            param.grad.reshape(-1) for param in model_copy.parameters()
            if param.grad is not None
        ])**2
        # inf *and* NaN would poison the sum; skip such samples.
        if not torch.isfinite(current_grad).all():
            continue
        n_valid += 1
        if total_grad is None:
            total_grad = current_grad
        else:
            total_grad = total_grad + current_grad

    if total_grad is None:
        raise ValueError("No valid gradient sample could be collected")
    return float(n_valid) / (total_grad + 1e-8)
Esempio n. 7
0
def _sanitize_images(images):
    """Replace NaN/inf entries with 0.5 and clamp the values to [-0.5, 0.5].

    NaN/inf are overwritten in place on ``images``; the clamped tensor is
    returned.
    """
    images[torch.isnan(images)] = 0.5
    images[torch.isinf(images)] = 0.5
    return torch.clamp(images, -0.5, 0.5)


def _estimate_step_size(model, args, temperature, output_folder):
    """Estimate the interpolation step size from a batch of sampled images.

    The steps are:
    - Sample a batch from the prior.
    - Binarise modality ``args.step_modality`` (threshold 0 for binary
      data, Otsu otherwise).
    - Fit a line through the first points of the averaged two-point
      correlation to obtain a chord length.

    The chord length is truncated to an int and is never smaller than 1.
    """
    fig_dir = os.path.join(output_folder, 'stepnum_results')
    # NOTE(review): fig_dir is created but nothing is written to it here.
    if not os.path.exists(fig_dir):
        os.mkdir(fig_dir)

    print('No step size entered')

    # Create sample of images to estimate chord length
    with torch.no_grad():
        mean, logs = model.prior(None, None)
        z = gaussian_sample(mean, logs, temperature)
        images_raw = model(z=z, temperature=temperature, reverse=True)
    images_raw = _sanitize_images(images_raw)

    images_out = np.transpose(
        np.squeeze(images_raw[:, args.step_modality, :, :].cpu().numpy()),
        (1, 0, 2))

    # Threshold images and compute covariances
    if args.binary_data:
        thresh = 0
    else:
        thresh = threshold_otsu(images_out)
    images_bin = np.greater(images_out, thresh)
    x_cov = two_point_correlation(images_bin, 0)
    y_cov = two_point_correlation(images_bin, 1)

    # Compute chord length from the slope of the covariance at the origin
    cov_avg = np.mean(np.mean(np.concatenate((x_cov, y_cov), axis=2),
                              axis=0),
                      axis=0)
    n_fit = 5
    slope, _ = curve_fit(straight_line_at_origin(cov_avg[0]),
                         range(0, n_fit), cov_avg[0:n_fit])
    # curve_fit returns a parameter array. Converting the chord length to
    # an int only makes sense for a single fitted parameter.
    l_pore = np.abs(cov_avg[0] / slope[0])
    # A chord length below 1 would give steps == 0. That makes linspace
    # empty and divides by zero when computing the number of images.
    steps = max(1, int(l_pore))
    print('Calculated step size: {}'.format(steps))
    return steps


def main(args):
    """Load a trained Glow model and sample 3-D image volumes from it.

    Inputs:
    - ``results/<args.name>/hparams.json``.
    - The checkpoint ``results/<args.name>/checkpoints/<args.model>``.

    Procedure:
    - Build ``args.iter`` volumes by linearly interpolating (``steps``
      frames per pair) between successive prior samples.
    - If ``args.steps`` is None, estimate the step size from the chord
      length of sampled images.
    - Optionally median-filter, erode (binary data) or Otsu-binarise each
      volume.

    Output: xy / xz / yz videos per modality under
    ``results/<args.name>/image_stacks``.
    """
    # NOTE: no RNG seeding is done here, so sampling is not reproducible.

    # Test loading and sampling
    output_folder = os.path.join('results', args.name)

    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    image_shape = (hparams['patch_size'], hparams['patch_size'],
                   args.n_modalities)
    num_classes = 1

    print('Loading model...')
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model_chkpt = torch.load(
        os.path.join(output_folder, 'checkpoints', args.model),
        map_location=device)
    model.load_state_dict(model_chkpt['model'])
    model.set_actnorm_init()
    model = model.to(device)

    # Build images
    model.eval()
    temperature = args.temperature

    if args.steps is None:  # automatically calculate step size
        steps = _estimate_step_size(model, args, temperature, output_folder)
    else:
        print('Using user-entered step size {}...'.format(args.steps))
        steps = args.steps

    # Build desired number of volumes
    for iter_vol in range(args.iter):
        if args.iter == 1:
            stack_dir = os.path.join(output_folder, 'image_stacks',
                                     args.save_name)
            print('Sampling images, saving to {}...'.format(args.save_name))
        else:
            stack_dir = os.path.join(
                output_folder, 'image_stacks',
                args.save_name + '_' + str(iter_vol).zfill(3))
            print('Sampling images, saving to {}_'.format(args.save_name) +
                  str(iter_vol).zfill(3) + '...')
        if not os.path.exists(stack_dir):
            os.makedirs(stack_dir)

        with torch.no_grad():
            mean, logs = model.prior(None, None)
            # Interpolation weights from 1 down to 0, shaped for broadcasting
            # against a single latent.
            alpha = 1 - torch.reshape(torch.linspace(0, 1, steps=steps),
                                      (-1, 1, 1, 1))
            alpha = alpha.to(device)

            # Enough anchor latents so that (num_imgs - 1) * steps frames
            # cover at least patch_size slices.
            num_imgs = int(np.ceil(hparams['patch_size'] / steps) + 1)
            z = gaussian_sample(mean, logs, temperature)[:num_imgs, ...]
            z = torch.cat([
                alpha * z[i, ...] + (1 - alpha) * z[i + 1, ...]
                for i in range(num_imgs - 1)
            ])
            z = z[:hparams['patch_size'], ...]

            images_raw = model(z=z, temperature=temperature, reverse=True)

        images_raw = _sanitize_images(images_raw)

        # apply median filter to output
        if args.med_filt is not None or args.binary_data:
            for m in range(args.n_modalities):
                if args.binary_data:
                    SE = ball(1)
                else:
                    SE = ball(args.med_filt)
                images_np = np.squeeze(images_raw[:, m, :, :].cpu().numpy())
                images_filt = median_filter(images_np, footprint=SE)

                # Erode binary images
                if args.binary_data:
                    images_filt = np.greater(images_filt, 0)
                    SE = ball(1)
                    images_filt = 1.0 * binary_erosion(images_filt,
                                                       selem=SE) - 0.5

                images_raw[:, m, :, :] = torch.tensor(images_filt,
                                                      device=device)

        images1 = postprocess(images_raw).cpu()
        images2 = postprocess(torch.transpose(images_raw, 0, 2)).cpu()
        images3 = postprocess(torch.transpose(images_raw, 0, 3)).cpu()

        # apply Otsu thresholding to output
        if args.save_binary and not args.binary_data:
            thresh = threshold_otsu(images1.numpy())
            # Use a single mask per stack so that every pixel ends up as
            # exactly 0 or 255, including pixels equal to the threshold.
            for images in (images1, images2, images3):
                foreground = images > thresh
                images[~foreground] = 0
                images[foreground] = 255

        # save video for each modality
        for m in range(args.n_modalities):
            if args.n_modalities > 1:
                save_dir = os.path.join(stack_dir, 'modality{}'.format(m))
            else:
                save_dir = stack_dir

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            write_video(images1[:, m, :, :], 'xy', hparams, save_dir)
            write_video(images2[:, m, :, :], 'xz', hparams, save_dir)
            write_video(images3[:, m, :, :], 'yz', hparams, save_dir)

    print('Finished!')
Esempio n. 8
0
def fischer_approximation_from_model(model,
                                     T=1000,
                                     temperature=1,
                                     type_fischer="generated",
                                     sampling_dataset=None):
    """Estimate a diagonal empirical Fisher information matrix.

    The squared NLL gradient of every parameter that receives a gradient is
    averaged over up to ``T`` valid samples. A sample is skipped, and counted
    in the statistics printed at the end, when either:
    - no parameter receives a gradient ("empty"), or
    - its squared gradient contains inf/NaN ("inf").

    Args:
        model: Model whose ``forward(x)`` returns a 3-tuple with the NLL in
            the middle. Unless ``type_fischer == "sampled"``, it must also
            expose ``prior(None, batch_size=...)`` and
            ``flow(z, temperature=..., reverse=True)``.
        T: Maximum number of valid samples to average.
        temperature: Sampling temperature for generated samples.
        type_fischer:
            - ``"generated"`` draws samples from the model.
            - ``"sampled"`` draws them from ``sampling_dataset`` in random
              order.
            - Anything else returns an all-ones vector of the Fisher's size.
        sampling_dataset: Indexable dataset (``dataset[i][0]`` is the
            input). Required for ``"sampled"``. When given, it also bounds
            the loop:
            - In ``"sampled"`` mode, at most ``len(sampling_dataset)`` items
              are drawn.
            - In ``"generated"`` mode, at most ``len(sampling_dataset)``
              valid samples are accumulated.

    Returns:
        1-D tensor of averaged squared gradients plus ``1e-8``. For the
        uniform mode it is a tensor of ones without the offset.

    Raises:
        ValueError: if ``"sampled"`` is requested without a dataset, or no
            valid sample could be collected.
    """
    device = next(model.parameters()).device
    fischer_matrix = None
    n = 0  # number of valid samples accumulated

    compteur_empty = 0  # samples where no parameter received a gradient
    compteur_inf = 0  # samples skipped because of inf/NaN gradients

    dataset_len = None if sampling_dataset is None else len(sampling_dataset)
    if dataset_len is not None:
        indexes = np.arange(0, dataset_len, step=1)
        random.shuffle(indexes)
    elif type_fischer == "sampled":
        raise ValueError('type_fischer="sampled" requires a sampling_dataset')

    def _generate():
        """Draw one sample from the model without tracking gradients."""
        with torch.no_grad():
            mean, logs = model.prior(None, batch_size=1)
            z = gaussian_sample(mean, logs, temperature=temperature)
            return model.flow(z, temperature=temperature, reverse=True)[0]

    index = 0
    print(f"Fischer with type {type_fischer}")
    while n < T and (dataset_len is None or index < dataset_len):
        if index % 100 == 0:
            print(f"Index {index} on {dataset_len}, n {n} on {T}")
        if type_fischer == "sampled":
            x = sampling_dataset[indexes[index]][0].to(device)
            # Consume the item even if it is skipped below; otherwise an item
            # with a non-finite gradient would be retried forever.
            index += 1
        else:
            # "generated" and the uniform mode both draw from the model.
            x = _generate()

        model.zero_grad()
        _, nll, _ = model(x.unsqueeze(0))
        nll.backward()
        grads = [
            param.grad.reshape(-1) for param in model.parameters()
            if param.grad is not None
        ]
        if not grads:
            print("No grad calculation available")
            compteur_empty += 1
            continue
        # The sign of the gradient is irrelevant once squared.
        current_grad = torch.cat(grads)**2
        # Skip the whole sample instead of dropping individual parameters,
        # so every accumulated vector has the same length and layout.
        if not torch.isfinite(current_grad).all():
            print("Found inf here in current grad")
            compteur_inf += 1
            continue

        if type_fischer not in ("generated", "sampled"):
            # Uniform mode: only the size of the gradient vector is needed.
            return torch.ones_like(current_grad)

        if fischer_matrix is None:
            fischer_matrix = current_grad
        else:
            # Running mean over the n previous samples plus this one.
            fischer_matrix = (n / (n + 1)) * fischer_matrix + (
                1 / (n + 1)) * current_grad
        n += 1
        if type_fischer == "generated":
            index += 1

    print(f"Number of empty is {compteur_empty}")
    print(f"Number of inf is {compteur_inf}")

    if fischer_matrix is None:
        raise ValueError("No valid gradient sample could be collected")
    return fischer_matrix + 1e-8