Ejemplo n.º 1
0
        return res

    def __str__(self):
        """Human-readable tag built from this model's configured settings."""
        parts = (f'SuperResolution.scale_factor={self.scale_factor}',
                 f'.mode={self.mode}')
        return ''.join(parts)


def get_forward_model(fm_name, **fm_kwargs):
    """Instantiate the forward model named *fm_name* defined in this module.

    Looks the name up on the current module and calls it, forwarding any
    keyword arguments to the constructor.
    """
    this_module = sys.modules[__name__]
    model_cls = getattr(this_module, fm_name)
    return model_cls(**fm_kwargs)


if __name__ == '__main__':
    # Smoke test: push one image through the masked-FFT compression
    # operators and write the round-trip result (plus the mask planes)
    # to PNG files for visual inspection.
    from PIL import Image

    from utils import load_target_image
    img = load_target_image('./images/val_celeba128_cropped20/192571.pt', 128)
    img = img.unsqueeze(0)  # batch dim
    # NOTE(review): rand_mask presumably builds a random frequency-sampling
    # mask; the second argument looks like a keep-budget -- TODO confirm.
    mask, ratio = rand_mask(128, 128 * 0.5 / 2)
    print('before', img.shape, mask.shape, ratio, 'min/max:', img.min(),
          img.max())
    masked_img = compress_FFT(img, mask)
    print('during', masked_img.shape, masked_img.min(), masked_img.max())
    # compress_FFT_t appears to invert/adjoint compress_FFT, reconstructing
    # an image from the masked spectrum -- verify against its definition.
    rec_img = compress_FFT_t(masked_img, mask)
    print('after', rec_img.shape, rec_img.min(), rec_img.max())

    def to_img(x):
        # Scale to 0-255 uint8 and reorder CHW -> HWC for PIL.
        # Assumes x holds floats in [0, 1] -- TODO confirm.
        return (x * 255).to(torch.uint8).squeeze().numpy().transpose([1, 2, 0])

    Image.fromarray(to_img(rec_img)).save('test_fft.img.png')
    # The mask seems to carry a trailing 2-wide axis (real/imag planes?);
    # save each plane separately -- confirm mask layout.
    Image.fromarray(to_img(mask.squeeze()[..., 0])).save('test_fft.mask0.png')
    Image.fromarray(to_img(mask.squeeze()[..., 1])).save('test_fft.mask1.png')
Ejemplo n.º 2
0
    gen = load_trained_net(
        gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
              '.n_cuts=0/gen_ckpt.49.pt'))
    gen = gen.eval().to(DEVICE)

    n_cuts = 3

    img_size = 128
    img_shape = (3, img_size, img_size)

    forward_model = GaussianCompressiveSensing(n_measure=2500,
                                               img_shape=img_shape)
    # forward_model = NoOp()

    for img_name in tqdm(os.listdir(args.img_dir),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     img_size).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)
        x_hat, x_degraded, _ = recover(orig_img,
                                       gen,
                                       optimizer_type='lbfgs',
                                       n_cuts=n_cuts,
                                       forward_model=forward_model,
                                       z_lr=1.0,
                                       n_steps=25,
                                       run_dir='ours',
                                       run_name=img_basename)
Ejemplo n.º 3
0
def mgan_images(args):
    """Run mGAN recovery over every image in ``args.img_dir``.

    Loads the generator selected by ``args.model``, then sweeps the
    cross-product of n_cuts values, forward models, and (z_init_mode,
    limit) pairs, saving the recovered image, metadata, and PSNR under
    ``BASE_DIR``.  Existing results are skipped unless ``args.overwrite``.
    """
    if args.set_seed:
        # Pin every RNG source so recoveries are reproducible.
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)

    # Select and load the pretrained generator for the requested model.
    if args.model in ['mgan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['mgan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        gen = gen.decoder
        img_size = 128
    elif args.model in ['mgan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'))
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        raise NotImplementedError()

    img_shape = (3, img_size, img_size)
    metadata = recovery_settings[args.model]

    # The sweep lists are consumed here and removed so that the remaining
    # metadata dict holds only per-run recovery parameters.
    n_cuts_list = metadata.pop('n_cuts_list')
    z_init_mode_list = metadata.pop('z_init_mode')
    limit_list = metadata.pop('limit')
    assert len(z_init_mode_list) == len(limit_list)

    forwards = forward_models[args.model]

    data_split = Path(args.img_dir).name
    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     img_size).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)

        for n_cuts in tqdm(n_cuts_list,
                           desc='N_cuts',
                           leave=False,
                           disable=args.disable_tqdm):
            metadata['n_cuts'] = n_cuts
            for f, f_args_list in tqdm(forwards.items(),
                                       desc='Forwards',
                                       leave=False,
                                       disable=args.disable_tqdm):
                for f_args in tqdm(f_args_list,
                                   desc=f'{f} Args',
                                   leave=False,
                                   disable=args.disable_tqdm):

                    f_args['img_shape'] = img_shape
                    forward_model = get_forward_model(f, **f_args)

                    for z_init_mode, limit in zip(
                            tqdm(z_init_mode_list,
                                 desc='z_init_mode',
                                 leave=False), limit_list):
                        metadata['z_init_mode'] = z_init_mode
                        metadata['limit'] = limit

                        # Before doing recovery, check if results already exist
                        # and possibly skip
                        recovered_name = 'recovered.pt'
                        results_folder = get_results_folder(
                            image_name=img_basename,
                            model=args.model,
                            n_cuts=n_cuts,
                            split=data_split,
                            forward_model=forward_model,
                            recovery_params=dict_to_str(metadata),
                            base_dir=BASE_DIR)

                        os.makedirs(results_folder, exist_ok=True)

                        recovered_path = results_folder / recovered_name
                        if os.path.exists(
                                recovered_path) and not args.overwrite:
                            print(
                                f'{recovered_path} already exists, skipping...'
                            )
                            continue

                        if args.run_name is not None:
                            current_run_name = (
                                f'{img_basename}.{forward_model}'
                                f'.{dict_to_str(metadata)}'
                                f'.{args.run_name}')
                        else:
                            current_run_name = None

                        recovered_img, distorted_img, _ = mgan_recover(
                            orig_img, gen, n_cuts, forward_model,
                            metadata['optimizer'], z_init_mode, limit,
                            metadata['z_lr'], metadata['n_steps'],
                            metadata['z_number'], metadata['restarts'],
                            args.run_dir, current_run_name, args.disable_tqdm)

                        # Make images folder
                        img_folder = get_images_folder(split=data_split,
                                                       image_name=img_basename,
                                                       img_size=img_size,
                                                       base_dir=BASE_DIR)
                        os.makedirs(img_folder, exist_ok=True)

                        # Save original image if needed
                        original_img_path = img_folder / 'original.pt'
                        if not os.path.exists(original_img_path):
                            torch.save(orig_img, original_img_path)

                        # Save distorted image if needed
                        if forward_model.viewable:
                            distorted_img_path = img_folder / f'{forward_model}.pt'
                            if not os.path.exists(distorted_img_path):
                                torch.save(distorted_img, distorted_img_path)

                        # Save recovered image and metadata.  Context managers
                        # guarantee the pickle files are flushed and closed.
                        torch.save(recovered_img, recovered_path)
                        with open(results_folder / 'metadata.pkl', 'wb') as fh:
                            pickle.dump(metadata, fh)
                        p = psnr(recovered_img, orig_img)
                        with open(results_folder / 'psnr.pkl', 'wb') as fh:
                            pickle.dump(p, fh)
Ejemplo n.º 4
0
def deep_decoder_images(args):
    """Run Deep Decoder recovery over every image in ``args.img_dir``.

    Sweeps the configured forward models and their argument lists, saving
    the recovered image, metadata, and PSNR under ``BASE_DIR``.  Existing
    results are skipped unless ``args.overwrite``.
    """
    if args.set_seed:
        # Pin every RNG source so recoveries are reproducible.
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)

    metadata = recovery_settings[args.model]
    forwards = forward_models[args.model]

    data_split = Path(args.img_dir).name
    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     metadata['img_size']).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)

        for f, f_args_list in tqdm(forwards.items(),
                                   desc='Forwards',
                                   leave=False,
                                   disable=args.disable_tqdm):
            for f_args in tqdm(f_args_list,
                               desc=f'{f} Args',
                               leave=False,
                               disable=args.disable_tqdm):
                f_args['img_shape'] = (3, metadata['img_size'],
                                       metadata['img_size'])
                forward_model = get_forward_model(f, **f_args)

                # Before doing recovery, check if results already exist
                # and possibly skip
                recovered_name = 'recovered.pt'
                results_folder = get_results_folder(
                    image_name=img_basename,
                    model=args.model,
                    n_cuts=0,  # NOTE - this field is unused for this baseline
                    split=data_split,
                    forward_model=forward_model,
                    recovery_params=dict_to_str(metadata),
                    base_dir=BASE_DIR)

                os.makedirs(results_folder, exist_ok=True)

                recovered_path = results_folder / recovered_name
                if os.path.exists(recovered_path) and not args.overwrite:
                    print(f'{recovered_path} already exists, skipping...')
                    continue

                if args.run_name is not None:
                    current_run_name = (f'{img_basename}' +
                                        f'.{forward_model}' +
                                        dict_to_str(metadata) +
                                        f'.{args.run_name}')
                else:
                    current_run_name = None

                recovered_img, distorted_img, _ = deep_decoder_recover(
                    orig_img,
                    forward_model=forward_model,
                    optimizer=metadata['optimizer'],
                    num_filters=metadata['num_filters'],
                    depth=metadata['depth'],
                    lr=metadata['lr'],
                    img_size=metadata['img_size'],
                    steps=metadata['steps'],
                    restarts=metadata['restarts'],
                    run_dir=args.run_dir,
                    run_name=current_run_name,
                    disable_tqdm=args.disable_tqdm)

                # Make images folder
                img_folder = get_images_folder(split=data_split,
                                               image_name=img_basename,
                                               img_size=metadata['img_size'],
                                               base_dir=BASE_DIR)
                os.makedirs(img_folder, exist_ok=True)

                # Save original image if needed
                original_img_path = img_folder / 'original.pt'
                if not os.path.exists(original_img_path):
                    torch.save(orig_img, original_img_path)

                # Save distorted image if needed
                if forward_model.viewable:
                    distorted_img_path = img_folder / f'{forward_model}.pt'
                    if not os.path.exists(distorted_img_path):
                        torch.save(distorted_img, distorted_img_path)

                # Save recovered image and metadata.  Context managers
                # guarantee the pickle files are flushed and closed.
                torch.save(recovered_img, recovered_path)
                with open(results_folder / 'metadata.pkl', 'wb') as fh:
                    pickle.dump(metadata, fh)
                p = psnr(recovered_img, orig_img)
                with open(results_folder / 'psnr.pkl', 'wb') as fh:
                    pickle.dump(p, fh)
Ejemplo n.º 5
0
def lasso_cs_images(args):
    """Run Lasso compressed-sensing baselines over every image in ``args.img_dir``.

    For each (n_measure, lasso_coeff) pair in the model's baseline settings,
    recovers the image and saves the recovery, metadata, and PSNR under
    ``BASE_DIR``.  Existing results are skipped unless ``args.overwrite``.
    """
    if args.set_seed:
        # Pin every RNG source so recoveries are reproducible.
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)
    if args.model in ['lasso-dct-64', 'lasso-dct-128']:
        recover_fn = recover_dct
    else:
        raise NotImplementedError()

    metadata = baseline_settings[args.model]
    # Each n_measure must have a matching lasso coefficient.
    assert len(metadata['n_measure']) == len(metadata['lasso_coeff'])

    data_split = Path(args.img_dir).name
    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image (as an HWC numpy array) and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     metadata['img_size']).numpy().transpose(
                                         [1, 2, 0])
        img_basename, _ = os.path.splitext(img_name)

        for n_measure, lasso_coeff in zip(
                tqdm(metadata['n_measure'],
                     desc='N_measure',
                     leave=False,
                     disable=args.disable_tqdm), metadata['lasso_coeff']):

            # Before doing recovery, check if results already exist
            # and possibly skip
            recovered_name = 'recovered.npy'
            results_folder = get_baseline_results_folder(
                image_name=img_basename,
                model=args.model,
                split=data_split,
                n_measure=n_measure,
                lasso_coeff=lasso_coeff,
                base_dir=BASE_DIR)

            os.makedirs(results_folder, exist_ok=True)

            recovered_path = results_folder / recovered_name
            if os.path.exists(recovered_path) and not args.overwrite:
                print(f'{recovered_path} already exists, skipping...')
                continue

            recovered_img = recover_fn(orig_img, n_measure, lasso_coeff,
                                       metadata['img_size'])

            # Make images folder
            img_folder = get_images_folder(split=data_split,
                                           image_name=img_basename,
                                           img_size=metadata['img_size'],
                                           base_dir=BASE_DIR)
            os.makedirs(img_folder, exist_ok=True)

            # Save original image if needed
            original_img_path = img_folder / 'original.npy'
            if not os.path.exists(original_img_path):
                np.save(original_img_path, orig_img)

            # Save recovered image and metadata.  Context managers guarantee
            # the pickle files are flushed and closed.
            np.save(recovered_path, recovered_img)
            with open(results_folder / 'metadata.pkl', 'wb') as fh:
                pickle.dump(metadata, fh)
            with open(results_folder / 'psnr.pkl', 'wb') as fh:
                pickle.dump(psnr(recovered_img, orig_img), fh)