Example #1
def main(args, kwargs):
    output_folder = args.output_folder
    model_name = args.model_name

    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    image_shape, num_classes, _, test_mnist = get_MNIST(False, hparams['dataroot'], hparams['download'])
    test_loader = data.DataLoader(test_mnist, batch_size=32,
                                  shuffle=False, num_workers=6,
                                  drop_last=False)
    x, y = next(iter(test_loader))
    x = x.to(device)

    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,
                 hparams['learn_top'], hparams['y_condition'],
                 hparams.get('logittransform', False), hparams.get('sn', False))

    model.load_state_dict(torch.load(os.path.join(output_folder, model_name)))
    model.set_actnorm_init()

    model = model.to(device)
    model = model.eval()

    with torch.no_grad():
        # Sample from the prior, then reconstruct from the stored latents,
        # with and without the saved split (the last two calls are identical
        # in this example).
        images = model(y_onehot=None, temperature=1, batch_size=32, reverse=True).cpu()
        better_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True, use_last_split=True).cpu()
        dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()
        worse_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()

    l2_err = torch.pow((images - dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    better_l2_err = torch.pow((images - better_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    worse_l2_err = torch.pow((images - worse_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    print(l2_err, better_l2_err, worse_l2_err)
    plot_imgs([images, dup_images, better_dup_images, worse_dup_images], '_recons')

    # Encode real data, then reconstruct it from its own latents.
    with torch.no_grad():
        z, nll, y_logits = model(x, None)
        better_dup_images = model(y_onehot=None, temperature=1, z=z, reverse=True, use_last_split=True).cpu()

    plot_imgs([x, better_dup_images], '_data_recons2')

    fpath = os.path.join(output_folder, '_recon_evolution.png')
    pad = run_recon_evolution(model, x, fpath)
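plot_imgs is used above but not defined in this snippet. A minimal sketch of
such a helper, assuming it tiles each [N, C, H, W] batch into a grid and writes
one PNG per list entry (the signature and output path are assumptions, not the
original):

import os
import torchvision

def plot_imgs(batches, suffix, out_dir='.'):
    # Hypothetical helper: save each batch of images as a tiled grid named
    # '<index><suffix>.png' under out_dir.
    for idx, batch in enumerate(batches):
        grid = torchvision.utils.make_grid(batch.detach().cpu(), nrow=8)
        torchvision.utils.save_image(grid, os.path.join(out_dir, '{}{}.png'.format(idx, suffix)))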
Example #2
    )  #'/home/yellow/deep-learning-and-practice/hw7/dataset/task_2/'
    test_loader = DataLoader(dataset_test,
                             batch_size=Batch_Size,
                             shuffle=False,
                             drop_last=True)
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    model.load_state_dict(
        torch.load(output_folder + model_name, map_location="cpu")['model'])
    model.set_actnorm_init()
    model = model.to(device)
    model = model.eval()

    # attribute_list = [8] # Black_Hair
    attribute_list = [20, 31, 33]  # Male, Smiling, Wavy_Hair (24: No_Beard)
    # attribute_list = [11, 26, 31, 8, 6, 7] # Brown_Hair, Pale_Skin, Smiling, Black_Hair, Big_Lips, Big_Nose
    # attribute_list = [i for i in range(40)]
    N = 8
    z_pos_list = [torch.Tensor([]).cuda() for i in range(len(attribute_list))]
    z_neg_list = [torch.Tensor([]).cuda() for i in range(len(attribute_list))]

    z_input_img = None
    with torch.no_grad():
        for i, (x, y) in enumerate(test_loader):
            print('reading data: ', i, '/', 30000 / Batch_Size)
            if i >= 3000:
                break
            for j, attribute_num in enumerate(attribute_list):
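The snippet is truncated at this point. A typical continuation of this latent
attribute-manipulation pattern, assuming the loop above encodes images and
appends their latents to z_pos_list/z_neg_list per attribute (a sketch, not
the original code):

# Assumed continuation, not the original: average the latents of images with
# and without each attribute, then walk an input latent along the difference
# vector and decode each step.
with torch.no_grad():
    for j, attribute_num in enumerate(attribute_list):
        z_dir = z_pos_list[j].mean(dim=0) - z_neg_list[j].mean(dim=0)
        steps = torch.linspace(-1.0, 1.0, N, device=z_dir.device)
        edited = torch.stack([z_input_img + a * z_dir for a in steps])
        imgs = model(z=edited, temperature=1, reverse=True).cpu()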
Example #3
def main(args):
    # torch.manual_seed(args.seed)

    # Test loading and sampling
    output_folder = os.path.join('results', args.name)

    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    image_shape = (hparams['patch_size'], hparams['patch_size'],
                   args.n_modalities)
    num_classes = 1

    print('Loading model...')
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    model_chkpt = torch.load(
        os.path.join(output_folder, 'checkpoints', args.model))
    model.load_state_dict(model_chkpt['model'])
    model.set_actnorm_init()
    model = model.to(device)

    # Build images
    model.eval()
    temperature = args.temperature

    if args.steps is None:  # estimate the step size automatically if none was given

        fig_dir = os.path.join(output_folder, 'stepnum_results')
        if not os.path.exists(fig_dir):
            os.mkdir(fig_dir)

        print('No step size entered')

        # Create sample of images to estimate chord length
        with torch.no_grad():
            mean, logs = model.prior(None, None)
            z = gaussian_sample(mean, logs, temperature)
            images_raw = model(z=z, temperature=temperature, reverse=True)
        # Guard against NaN/Inf pixels from the reverse pass, then clamp.
        images_raw[torch.isnan(images_raw)] = 0.5
        images_raw[torch.isinf(images_raw)] = 0.5
        images_raw = torch.clamp(images_raw, -0.5, 0.5)

        images_out = np.transpose(
            np.squeeze(images_raw[:, args.step_modality, :, :].cpu().numpy()),
            (1, 0, 2))

        # Threshold images and compute covariances
        if args.binary_data:
            thresh = 0
        else:
            thresh = threshold_otsu(images_out)
        images_bin = np.greater(images_out, thresh)
        x_cov = two_point_correlation(images_bin, 0)
        y_cov = two_point_correlation(images_bin, 1)

        # Compute chord length
        cov_avg = np.mean(np.mean(np.concatenate((x_cov, y_cov), axis=2),
                                  axis=0),
                          axis=0)
        N = 5
        S20, _ = curve_fit(straight_line_at_origin(cov_avg[0]), range(0, N),
                           cov_avg[0:N])
        l_pore = np.abs(cov_avg[0] / S20)
        steps = int(l_pore)
        print('Calculated step size: {}'.format(steps))

    else:
        print('Using user-entered step size {}...'.format(args.steps))
        steps = args.steps

    # Build desired number of volumes
    for iter_vol in range(args.iter):
        if args.iter == 1:
            stack_dir = os.path.join(output_folder, 'image_stacks',
                                     args.save_name)
            print('Sampling images, saving to {}...'.format(args.save_name))
        else:
            stack_dir = os.path.join(
                output_folder, 'image_stacks',
                args.save_name + '_' + str(iter_vol).zfill(3))
            print('Sampling images, saving to {}_{}...'.format(
                args.save_name, str(iter_vol).zfill(3)))
        if not os.path.exists(stack_dir):
            os.makedirs(stack_dir)

        with torch.no_grad():
            mean, logs = model.prior(None, None)
            # Interpolation weights running from 1 down to 0, one per step.
            alpha = 1 - torch.reshape(torch.linspace(0, 1, steps=steps),
                                      (-1, 1, 1, 1))
            alpha = alpha.to(device)

            num_imgs = int(np.ceil(hparams['patch_size'] / steps) + 1)
            z = gaussian_sample(mean, logs, temperature)[:num_imgs, ...]
            # Linearly interpolate between consecutive prior samples so
            # adjacent slices of the generated volume vary smoothly.
            z = torch.cat([
                alpha * z[i, ...] + (1 - alpha) * z[i + 1, ...]
                for i in range(num_imgs - 1)
            ])
            z = z[:hparams['patch_size'], ...]

            images_raw = model(z=z, temperature=temperature, reverse=True)

        images_raw[torch.isnan(images_raw)] = 0.5
        images_raw[torch.isinf(images_raw)] = 0.5
        images_raw = torch.clamp(images_raw, -0.5, 0.5)

        # apply median filter to output
        if args.med_filt is not None or args.binary_data:
            for m in range(args.n_modalities):
                if args.binary_data:
                    SE = ball(1)
                else:
                    SE = ball(args.med_filt)
                images_np = np.squeeze(images_raw[:, m, :, :].cpu().numpy())
                images_filt = median_filter(images_np, footprint=SE)

                # Erode binary images
                if args.binary_data:
                    images_filt = np.greater(images_filt, 0)
                    SE = ball(1)
                    images_filt = 1.0 * binary_erosion(images_filt,
                                                       selem=SE) - 0.5

                images_raw[:, m, :, :] = torch.tensor(images_filt,
                                                      device=device)

        # Post-process the volume along three orthogonal slice orientations.
        images1 = postprocess(images_raw).cpu()
        images2 = postprocess(torch.transpose(images_raw, 0, 2)).cpu()
        images3 = postprocess(torch.transpose(images_raw, 0, 3)).cpu()

        # apply Otsu thresholding to output
        if args.save_binary and not args.binary_data:
            thresh = threshold_otsu(images1.numpy())
            images1[images1 < thresh] = 0
            images1[images1 > thresh] = 255
            images2[images2 < thresh] = 0
            images2[images2 > thresh] = 255
            images3[images3 < thresh] = 0
            images3[images3 > thresh] = 255

        # # erode binary images by 1 px to correct for training image transformation
        # if args.binary_data:
        #     images1 = np.greater(images1.numpy(), 127)
        #     images2 = np.greater(images2.numpy(), 127)
        #     images3 = np.greater(images3.numpy(), 127)

        #     images1 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images1), selem=np.ones((1,2,2))), 1))
        #     images2 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images2), selem=np.ones((2,1,2))), 1))
        #     images3 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images3), selem=np.ones((2,2,1))), 1))

        # save video for each modality
        for m in range(args.n_modalities):
            if args.n_modalities > 1:
                save_dir = os.path.join(stack_dir, 'modality{}'.format(m))
            else:
                save_dir = stack_dir

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            write_video(images1[:, m, :, :], 'xy', hparams, save_dir)
            write_video(images2[:, m, :, :], 'xz', hparams, save_dir)
            write_video(images3[:, m, :, :], 'yz', hparams, save_dir)

    print('Finished!')
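straight_line_at_origin is a helper not shown in Example #3. From the way
curve_fit is called (a one-parameter fit whose slope S20 feeds the chord
length l_pore = |cov_avg[0] / S20|), a plausible sketch is the following
(an assumption, not the original):

import numpy as np

def straight_line_at_origin(y0):
    # Hypothetical: a line pinned to y0 at x = 0, so curve_fit estimates
    # only the initial slope of the two-point correlation curve.
    def line(x, slope):
        return y0 + slope * np.asarray(x, dtype=float)
    return line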
Example #4
def main(args):
    seed_list = [int(item) for item in args.seed.split(',')]

    for seed in seed_list:

        device = torch.device("cuda")

        experiment_folder = args.experiment_folder + '/' + str(seed) + '/'
        print(experiment_folder)
        #model_name = 'glow_checkpoint_'+ str(args.chk)+'.pth'
        for thing in os.listdir(experiment_folder):
            if 'best' in thing:
                model_name = thing
        print(model_name)

        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        with open(experiment_folder + 'hparams.json') as json_file:
            hparams = json.load(json_file)

        image_shape = (32, 32, 3)
        if hparams['y_condition']:
            num_classes = 2
            num_domains = 0
        elif hparams['d_condition']:
            num_classes = 10
            num_domains = 0
        elif hparams['yd_condition']:
            num_classes = 2
            num_domains = 10
        else:
            num_classes = 2
            num_domains = 0

        model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                     hparams['L'], hparams['actnorm_scale'],
                     hparams['flow_permutation'], hparams['flow_coupling'],
                     hparams['LU_decomposed'], num_classes, num_domains,
                     hparams['learn_top'], hparams['y_condition'],
                     hparams['extra_condition'], hparams['sp_condition'],
                     hparams['d_condition'], hparams['yd_condition'])
        print('loading model')
        model.load_state_dict(torch.load(experiment_folder + model_name))
        model.set_actnorm_init()

        model = model.to(device)

        model = model.eval()

        if hparams['y_condition']:
            print('y_condition')

            def sample(model, temp=args.temperature):
                with torch.no_grad():
                    if hparams['y_condition']:
                        print("extra", hparams['extra_condition'])
                        y = torch.eye(num_classes)
                        y = torch.cat(1000 * [y])
                        print(y.size())
                        y_0 = y[::2, :].to(
                            device)  # number hardcoded in model for now
                        y_1 = y[1::2, :].to(device)
                        print(y_0.size())
                        print(y_0)
                        print(y_1)
                        print(y_1.size())
                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            images0, images1 = sample(model)

            os.makedirs(experiment_folder + 'generations/Uninfected',
                        exist_ok=True)
            os.makedirs(experiment_folder + 'generations/Parasitized',
                        exist_ok=True)
            for i in range(images0.size(0)):
                torchvision.utils.save_image(
                    images0[i, :, :, :], experiment_folder +
                    'generations/Uninfected/sample_{}.png'.format(i))
                torchvision.utils.save_image(
                    images1[i, :, :, :], experiment_folder +
                    'generations/Parasitized/sample_{}.png'.format(i))
            images_concat0 = torchvision.utils.make_grid(images0[:64, :, :, :],
                                                         nrow=int(64**0.5),
                                                         padding=2,
                                                         pad_value=255)
            torchvision.utils.save_image(images_concat0,
                                         experiment_folder + '/uninfected.png')
            images_concat1 = torchvision.utils.make_grid(images1[:64, :, :, :],
                                                         nrow=int(64**0.5),
                                                         padding=2,
                                                         pad_value=255)
            torchvision.utils.save_image(
                images_concat1, experiment_folder + '/parasitized.png')

        elif hparams['d_condition']:
            print('d_cond')

            def sample_d(model, idx, batch_size=1000, temp=args.temperature):
                with torch.no_grad():
                    if hparams['d_condition']:

                        y_0 = torch.zeros([batch_size, 10], device='cuda:0')
                        y_0[:, idx] = torch.ones(batch_size)
                        y_0 = y_0.to(device)  # .to() is not in-place; keep the result
                        print(y_0)

                        # y_1 = torch.zeros([batch_size, 201], device='cuda:0')
                        # y_1[:, 157] = torch.ones(batch_size)
                        # y_1.to(device)
                        # y = torch.eye(num_classes)
                        # y = torch.cat(1000 * [y])
                        # print(y.size())
                        # y_0 = y[::2, :].to(device)  # number hardcoded in model for now
                        # y_1 = y[1::2, :].to(device)
                        # print(y_0.size())
                        # print(y_0)
                        # print(y_1)
                        # print(y_1.size())

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        # images1 = model(z=None, y_onehot=y_1, temperature=1.0, reverse=True, batch_size=1000)
                return images0

            for idx, dom in enumerate(["C116P77ThinF", "C132P93ThinF", "C137P98ThinF", "C180P141NThinF", "C182P143NThinF", \
                             "C184P145ThinF", "C39P4thinF", 'C59P20thinF', "C68P29N", "C99P60ThinF"]):

                images0 = sample_d(model, idx)

                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/',
                            exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/',
                            exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Uninfected/', exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Parasitized/', exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    #torchvision.utils.save_image(images1[i, :, :, :], experiment_folder + 'generations/C59P20thinF/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:25, :, :, :],
                    nrow=int(25**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(images_concat0,
                                             experiment_folder + dom + '.png')
                # images_concat1 = torchvision.utils.make_grid(images1[:64,:,:,:], nrow=int(64 ** 0.5), padding=2, pad_value=255)
                # torchvision.utils.save_image(images_concat1, experiment_folder + 'C59P20thinF.png')

        elif hparams['yd_condition']:

            def sample_YD(model, idx, batch_size=1000, temp=args.temperature):
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_0 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_0[:, 0] = torch.ones(batch_size)
                        y_0[:, idx + 2] = torch.ones(batch_size)
                        y_0 = y_0.to(device)
                        print(y_0)

                        y_1 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_1[:, 1] = torch.ones(batch_size)
                        y_1[:, idx + 2] = torch.ones(batch_size)
                        y_1 = y_1.to(device)
                        print(y_1)

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            def sample_DD(model, idx, batch_size=1000, temp=args.temperature):
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_1 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_1[:, idx] = torch.ones(batch_size)
                        y_1 = y_1.to(device)
                        print(y_1)

                        y_0 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_0[:, idx + 10] = torch.ones(batch_size)
                        y_0 = y_0.to(device)
                        print(y_0)

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            for idx, dom in enumerate(
                    ["C116P77ThinF", "C132P93ThinF", "C137P98ThinF", "C180P141NThinF", "C182P143NThinF",
                     "C184P145ThinF", "C39P4thinF", "C59P20thinF", "C68P29N", "C99P60ThinF"]):

                images0, images1 = sample_YD(model, idx)

                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/',
                            exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/',
                            exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    torchvision.utils.save_image(
                        images1[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:64, :, :, :],
                    nrow=int(64**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat0, experiment_folder + dom +
                    str(args.temperature) + '_uninfected.png')
                images_concat1 = torchvision.utils.make_grid(
                    images1[:64, :, :, :],
                    nrow=int(64**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat1, experiment_folder + dom +
                    str(args.temperature) + '_parasitized.png')

        else:

            def sample(model, temp=args.temperature):
                with torch.no_grad():
                    images = model(z=None,
                                   y_onehot=None,
                                   temperature=temp,
                                   reverse=True,
                                   batch_size=1000)

                return images

            images = sample(model)

            os.makedirs('unconditioned/' + str(seed) + '/generations/' +
                        experiment_folder[:-3],
                        exist_ok=True)
            for i in range(images.size(0)):
                torchvision.utils.save_image(
                    images[i, :, :, :],
                    'unconditioned/' + str(seed) + '/generations/' +
                    experiment_folder[:-2] + 'sample_{}.png'.format(i))

            images_concat = torchvision.utils.make_grid(images[:64, :, :, :],
                                                        nrow=int(64**0.5),
                                                        padding=2,
                                                        pad_value=255)
            torchvision.utils.save_image(
                images_concat, 'unconditioned/' + str(seed) + '/' +
                experiment_folder[:-3] + '.png')
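The three conditioning branches above all build one-hot (or stacked one-hot)
condition vectors by hand. A compact helper matching the sample_YD layout,
where the first two dimensions carry the class and the next ten the domain
(a sketch under those assumptions, not part of the original):

import torch

def make_yd_onehot(class_idx, domain_idx, batch_size,
                   n_classes=2, n_domains=10, device='cuda:0'):
    # Hypothetical helper mirroring sample_YD: dims [0, n_classes) carry the
    # class one-hot, dims [n_classes, n_classes + n_domains) the domain one-hot.
    y = torch.zeros(batch_size, n_classes + n_domains, device=device)
    y[:, class_idx] = 1.0
    y[:, n_classes + domain_idx] = 1.0
    return y

# e.g. the y_0 built inside sample_YD corresponds to make_yd_onehot(0, idx, 1000)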
Example #5
def main(dataset, dataroot, download, augment, n_workers, eval_batch_size, output_dir, db, glow_path, ckpt_name):

    (image_shape, num_classes, train_dataset, test_dataset) = check_dataset(dataset, dataroot, augment, download)

    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    x = next(iter(test_loader))[0].to(device)

    # OOD data
    ood_distributions = ['gaussian']
    # ood_distributions = ['gaussian', 'rademacher', 'texture3', 'svhn','tinyimagenet','lsun']
    tr = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        one_to_three_channels,
        preprocess,
    ])
    ood_tensors = [(out_name, torch.stack([tr(x) for x in load_ood_data({
                                  'name': out_name,
                                  'ood_scale': 1,
                                  'n_anom': eval_batch_size,
                                })]).to(device)
                        ) for out_name in ood_distributions]
    if 'sd' in glow_path:  # checkpoint path marks a raw state_dict
        with open(os.path.join(os.path.dirname(glow_path), 'hparams.json'), 'r') as f:
            model_kwargs = json.load(f)
        model = Glow(
                (32, 32, 3), 
                model_kwargs['hidden_channels'], 
                model_kwargs['K'], 
                model_kwargs['L'], 
                model_kwargs['actnorm_scale'],
                model_kwargs['flow_permutation'], 
                model_kwargs['flow_coupling'], 
                model_kwargs['LU_decomposed'], 
                10,
                model_kwargs['learn_top'], 
                model_kwargs['y_condition'],
                model_kwargs['logittransform'],
                model_kwargs['sn'],
                model_kwargs['affine_eps'],
                model_kwargs['no_actnorm'],
                model_kwargs['affine_scale_eps'], 
                model_kwargs['actnorm_max_scale'], 
                model_kwargs['no_conv_actnorm'],
                model_kwargs['affine_max_scale'],
                model_kwargs['actnorm_eps'],
                model_kwargs['no_split']
            )
        model.load_state_dict(torch.load(glow_path))
        model.set_actnorm_init()
    else:
        model = torch.load(glow_path)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples = generate_from_noise(model, eval_batch_size, clamp=False, guard_nans=False)
    stats = OrderedDict()
    for name, x in [('data', x), ('samples', samples)] + ood_tensors:
        p_pxs, p_ims, cn, dlogdet, bpd, pad = run_analysis(x, model, os.path.join(output_dir, f'recon_{ckpt_name}_{name}.jpeg'))
        
        stats[f"{name}-percent-pixels-nans"] =  p_pxs
        stats[f"{name}-percent-imgs-nans"] =  p_ims
        stats[f"{name}-cn"] =  cn
        stats[f"{name}-dlogdet"] =  dlogdet
        stats[f"{name}-bpd"] =  bpd
        stats[f"{name}-recon-err"] =  pad
        
        with open(os.path.join(output_dir, f'results_{ckpt_name}.json'), 'w') as fp:
            json.dump(stats, fp, indent=4)
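one_to_three_channels and preprocess come from the surrounding project and are
not shown here. Given its position in the transform pipeline (after ToTensor,
before preprocess), a minimal sketch of the first (an assumption, not the
original):

def one_to_three_channels(x):
    # Hypothetical: replicate single-channel tensors to three channels so
    # grayscale OOD images match the (32, 32, 3) Glow input; pass RGB through.
    return x.expand(3, -1, -1) if x.shape[0] == 1 else x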