예제 #1
0
    def reset_gen():
        """Build a freshly loaded generator for the model named in ``args.model``.

        Returns:
            tuple: ``(gen, img_size)``. ``gen`` is the network in eval mode on
            ``DEVICE`` (for the VAE only its ``decoder`` attribute is
            returned); ``img_size`` is 128 for the BEGAN and VAE branches and
            64 for the DCGAN branch.

        Raises:
            NotImplementedError: If ``args.model`` is not one of the three
                model names handled below.
        """
        model_name = args.model

        if model_name == 'iagan_began_cs':
            began = load_trained_net(
                Generator128(64),
                ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                 '.n_cuts=0/gen_ckpt.49.pt'))
            return began.eval().to(DEVICE), 128

        if model_name == 'iagan_dcgan_cs':
            dcgan = dcgan_generator()
            state = torch.load(
                ('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                 '.b1_0.5.lr_0.0002.pt'))
            dcgan.load_state_dict(state)
            return dcgan.eval().to(DEVICE), 64

        if model_name == 'iagan_vanilla_vae_cs':
            vae = VAE()
            state = torch.load(
                './vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
            vae.load_state_dict(state)
            # Only the decoder half is handed back to the caller.
            return vae.eval().to(DEVICE).decoder, 128

        raise NotImplementedError()
예제 #2
0
 def reset_gen(model):
     """Load a pretrained generator by short model name.

     Args:
         model: One of ``'began'``, ``'vae'`` or ``'dcgan'``.

     Returns:
         tuple: ``(gen, img_size)``. ``gen`` is the network in eval mode on
         ``DEVICE`` (for ``'vae'`` only the ``decoder`` attribute is
         returned); ``img_size`` is 128 for ``'began'`` and ``'vae'`` and 64
         for ``'dcgan'``.

     Raises:
         NotImplementedError: If ``model`` is not a supported name.
     """
     if model == 'began':
         gen = Generator128(64)
         gen = load_trained_net(
             gen,
             ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
              '.n_cuts=0/gen_ckpt.49.pt'))
         gen = gen.eval().to(DEVICE)
         img_size = 128
     elif model == 'vae':
         gen = VAE()
         t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
         gen.load_state_dict(t)
         gen = gen.eval().to(DEVICE)
         gen = gen.decoder
         img_size = 128
     elif model == 'dcgan':
         gen = dcgan_generator()
         t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                         '.b1_0.5.lr_0.0002.pt'))
         gen.load_state_dict(t)
         gen = gen.eval().to(DEVICE)
         img_size = 64
     else:
         # Without this branch an unknown name reached the return statement
         # with ``gen``/``img_size`` unbound and surfaced as a confusing
         # UnboundLocalError instead of a clear "unsupported model" error.
         raise NotImplementedError(f'Unknown model: {model!r}')
     return gen, img_size
예제 #3
0
def _load_began_for_cut_training(path):
    """Return a ``Generator128(64)`` on CUDA with weights loaded from *path*."""
    g = Generator128(64).to('cuda')
    return load_trained_net(g, path)


def _load_dcgan_for_cut_training(path):
    """Return a DCGAN generator on CUDA with the state dict at *path* loaded."""
    g = dcgan_generator().to('cuda')
    g.load_state_dict(torch.load(path))
    return g


def _plot_cut_training_grid(title, settings, load_generator, rescale, n_cols,
                            out_path):
    """Draw one row of ``n_cuts=1`` samples per setting and save the figure.

    Args:
        title: Figure super-title.
        settings: Dict mapping a 1-based row number to a dict holding at
            least a ``'path'`` entry (the checkpoint to load for that row).
        load_generator: Callable taking a checkpoint path and returning the
            loaded generator.
        rescale: If true, map the generator output with ``(img + 1) / 2``
            before clipping to ``[0, 1]`` for display.
        n_cols: Number of random samples drawn per row.
        out_path: File the figure is written to.
    """
    n_rows = len(settings)
    # squeeze=False keeps ``ax`` two-dimensional even when there is a single
    # row or column, so the ``ax[row, col]`` indexing below always works.
    fig, ax = plt.subplots(n_rows,
                           n_cols,
                           figsize=(n_cols, n_rows),
                           squeeze=False)

    fig.suptitle(title, fontsize=16)

    for i, cfg in settings.items():
        g = load_generator(cfg['path'])

        # Latent shapes expected by the generator when called with n_cuts=1.
        input_shapes = g.input_shapes[1]
        z1_shape = input_shapes[0]
        z2_shape = input_shapes[1]

        for col in range(n_cols):
            z1 = torch.randn(1, *z1_shape).clamp(-1, 1).to('cuda')
            if len(z2_shape) == 0:
                # An empty shape means this generator takes no second input.
                z2 = None
            else:
                z2 = torch.randn(1, *z2_shape).clamp(-1, 1).to('cuda')

            out = g.forward(z1, z2, n_cuts=1)
            # Drop the batch axis and move channels last for imshow.
            img = out.detach().cpu().squeeze(0).numpy().transpose([1, 2, 0])
            if rescale:
                # Rescale from [-1, 1] to [0, 1]
                img = (img + 1) / 2

            cell = ax[i - 1, col]  # setting keys are 1-based row numbers
            cell.imshow(np.clip(img, 0, 1), aspect='auto')
            cell.set_xticks([])
            cell.set_yticks([])
            cell.set_frame_on(False)

    fig.subplots_adjust(0, 0, 1, 0.93, 0, 0)

    os.makedirs('./figures/cut_training/', exist_ok=True)
    fig.savefig(out_path, bbox_inches='tight', dpi=300)


def cut_training(n_cols):
    """Save sample grids for the BEGAN and DCGAN generators trained with one cut.

    For each family, every checkpoint listed below gets one row of *n_cols*
    random samples generated with ``n_cuts=1``. The figures are written to
    ``./figures/cut_training/began_cut_training.pdf`` and
    ``./figures/cut_training/dcgan_cut_training.pdf``. Requires a CUDA device.

    Args:
        n_cols: Number of samples (columns) per checkpoint row.
    """
    # Only 'path' is read when plotting; the remaining keys record the
    # hyper-parameters that are also encoded in each checkpoint file name.
    began_settings = {
        1: {
            'batch_size': 32,
            'z_lr': 3e-5,
            'path':
            ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
             '.n_cuts=1.z_lr=3e-5/gen_ckpt.24.pt'),
        },
        2: {
            'batch_size': 32,
            'z_lr': 8e-5,
            'path':
            ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
             '.n_cuts=1.z_lr=8e-5/gen_ckpt.19.pt'),
        },
        3: {
            'batch_size': 64,
            'z_lr': 1e-4,
            'path':
            ('./checkpoints/celeba_began.withskips.bs64.cosine.min=0.25'
             '.n_cuts=1.z_lr=1e-4/gen_ckpt.19.pt'),
        },
        4: {
            'batch_size': 64,
            'z_lr': 3e-5,
            'path':
            ('./checkpoints/celeba_began.withskips.bs64.cosine.min=0.25'
             '.n_cuts=1.z_lr=3e-5/gen_ckpt.24.pt'),
        },
        5: {
            'batch_size': 64,
            'z_lr': 8e-5,
            'path':
            ('./checkpoints/celeba_began.withskips.bs64.cosine.min=0.25'
             '.n_cuts=1.z_lr=8e-5/gen_ckpt.24.pt'),
        },
    }

    _plot_cut_training_grid(
        title='BEGAN (cuts=1)',
        settings=began_settings,
        load_generator=_load_began_for_cut_training,
        rescale=False,
        n_cols=n_cols,
        out_path='./figures/cut_training/began_cut_training.pdf')

    dcgan_settings = {
        1: {
            'z_lr': 5e-5,
            'b1': 0.5,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.5.lr_5e-05.pt'),
        },
        2: {
            'z_lr': 1e-4,
            'b1': 0.5,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.5.lr_0.0001.pt'),
        },
        3: {
            'z_lr': 2e-4,
            'b1': 0.5,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.5.lr_0.0002.pt'),
        },
        4: {
            'z_lr': 5e-5,
            'b1': 0.9,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.9.lr_5e-05.pt'),
        },
        5: {
            'z_lr': 2e-4,
            'b1': 0.9,
            'path': ('./dcgan_checkpoints/netG.epoch_24.n_cuts_1'
                     '.bs_64.b1_0.9.lr_0.0002.pt'),
        },
    }

    # Unlike the BEGAN figure, DCGAN outputs are rescaled before display.
    _plot_cut_training_grid(
        title='DCGAN (cuts=1)',
        settings=dcgan_settings,
        load_generator=_load_dcgan_for_cut_training,
        rescale=True,
        n_cols=n_cols,
        out_path='./figures/cut_training/dcgan_cut_training.pdf')
예제 #4
0
def generator_samples(model):
    """Save a grid of generator samples, one row per ``n_cuts`` in 0..5.

    Each of the 10 columns reseeds torch and numpy with the column index
    before sampling, so a column uses the same seed in every row. The figure
    is written to ``./figures/generator_samples/model=<model>.pdf``.
    Requires the ``cuda:0`` device.

    Args:
        model: One of ``'began'``, ``'vae'``, ``'biggan'`` or ``'dcgan'``.

    Raises:
        NotImplementedError: If ``model`` is not one of the names above.
    """
    device = 'cuda:0'

    if model == 'began':
        g = load_trained_net(
            Generator128(64).to(device),
            ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
             '.n_cuts=0/gen_ckpt.49.pt'))
    elif model == 'vae':
        vae = VAE().to(device)
        vae.load_state_dict(
            torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt'))
        g = vae.decoder
    elif model == 'biggan':
        g = BigGanSkip().to(device)
    elif model == 'dcgan':
        g = dcgan_generator().to(device)
        g.load_state_dict(
            torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt')))
    else:
        raise NotImplementedError

    is_biggan = model == 'biggan'
    n_seeds = 10
    cut_levels = [0, 1, 2, 3, 4, 5]

    def biggan_cond_vector():
        """Concatenate fresh 128-d noise with the class-949 embedding."""
        label = torch.tensor(949, dtype=torch.long).to(device).unsqueeze(
            0)  # 949 = strawberry
        one_hot = torch.nn.functional.one_hot(
            label, num_classes=1000).to(torch.float)
        embed = g.biggan.embeddings(one_hot)
        noise = torch.randn(1, 128).to(device)
        return torch.cat((noise, embed), dim=1)

    def to_image(batch):
        """Drop the batch axis and move channels last, as a numpy array."""
        return batch.detach().cpu().squeeze(0).numpy().transpose([1, 2, 0])

    fig, ax = plt.subplots(len(cut_levels),
                           n_seeds,
                           figsize=(10, len(cut_levels)))

    for row, n_cuts in enumerate(cut_levels):
        shapes = g.input_shapes[n_cuts]
        z1_shape, z2_shape = shapes[0], shapes[1]

        for col in range(n_seeds):
            torch.manual_seed(col)
            np.random.seed(col)

            if is_biggan and n_cuts == 0:
                # Uncut BigGAN goes through the original forward pass.
                cond = biggan_cond_vector()
                out = orig_biggan_forward(g.biggan.generator,
                                          cond,
                                          truncation=1.0)
            elif is_biggan and n_cuts > 0:
                # z1 is drawn before the conditioning noise.
                z1 = torch.randn(1, *z1_shape).to(device)
                z2 = biggan_cond_vector()
                out = g(z1, z2, truncation=1.0, n_cuts=n_cuts)
            else:
                z1 = torch.randn(1, *z1_shape).to(device)
                z2 = (None if len(z2_shape) == 0 else
                      torch.randn(1, *z2_shape).to(device))
                out = g(z1, z2, n_cuts=n_cuts)

            img = to_image(out)
            if g.rescale:
                img = (img + 1) / 2

            cell = ax[row, col]
            cell.imshow(np.clip(img, 0, 1), aspect='auto')
            cell.set_xticks([])
            cell.set_yticks([])
            cell.set_frame_on(False)
            if col == 0:
                # Label each row with its n_cuts value.
                cell.set_ylabel(f'{n_cuts}')

    fig.subplots_adjust(0, 0, 1, 1, 0, 0)
    os.makedirs('./figures/generator_samples', exist_ok=True)
    out_path = (f'./figures/generator_samples/'
                f'model={model}.pdf')
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
예제 #5
0
def main(args):
    """Train a ``Generator128`` / ``Discriminator128`` pair, checkpointing per epoch.

    Checkpoints go to ``checkpoints/<dataset>_<run_name>`` and TensorBoard
    logs to ``tensorboard_logs/<dataset>_<run_name>``. If the checkpoint
    directory already exists, the newest index for which both a generator and
    a discriminator checkpoint are present is restored; any failure while
    restoring is printed and training proceeds with the current weights.

    Args:
        args: Namespace providing ``dataset``, ``run_name``, ``dataset_dir``,
            ``batch_size``, ``n_train``, ``latent_dim``, ``n_cuts``, ``lr``,
            ``epochs``, ``gamma``, ``prop_gain``, ``log_every`` and
            ``disable_tqdm``.
    """
    checkpoint_path = f"checkpoints/{args.dataset}_{args.run_name}"
    tensorboard_path = f"tensorboard_logs/{args.dataset}_{args.run_name}"
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    writer = SummaryWriter(tensorboard_path)

    dataloader, _ = get_dataloader(args.dataset_dir, args.batch_size,
                                   args.n_train, True)

    gen = Generator128(args.latent_dim).to(device)
    disc = Discriminator128(args.latent_dim).to(device)

    # Get latent_shape for x1 only
    # (read before the optional DataParallel wrap below — presumably because
    # the wrapper does not expose ``input_shapes``; confirm.)
    latent_shape = gen.input_shapes[args.n_cuts][0]

    if torch.cuda.device_count() > 1:
        gen = torch.nn.DataParallel(gen)
        disc = torch.nn.DataParallel(disc)

    # Both schedulers anneal over the total number of batches in the run
    # (they are stepped once per batch) down to a floor of 0.25 * lr.
    gen_optimizer = torch.optim.Adam(gen.parameters(), args.lr)
    gen_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        gen_optimizer,
        len(dataloader) * args.epochs, 0.25 * args.lr)
    disc_optimizer = torch.optim.Adam(disc.parameters(), args.lr)
    disc_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        disc_optimizer,
        len(dataloader) * args.epochs, 0.25 * args.lr)

    current_checkpoint = 0
    if (not os.path.exists(checkpoint_path)):
        os.makedirs(checkpoint_path)
    else:
        print("Restoring from checkpoint...")
        paths = os.listdir(checkpoint_path)
        try:
            # File names are expected to look like ``<prefix>.<index>.pt``;
            # any entry that does not parse raises and is handled by the
            # broad ``except`` below.
            available = sorted(set([int(x.split(".")[1]) for x in paths]))

            # Find a checkpoint that both gen AND disc have reached
            # Reaching zero will cause IndexError during pop()
            while True:
                latest_idx = available.pop()
                latest_disc = os.path.join(checkpoint_path,
                                           f"disc_ckpt.{latest_idx}.pt")
                latest_gen = os.path.join(checkpoint_path,
                                          f"gen_ckpt.{latest_idx}.pt")
                if os.path.exists(latest_disc) and os.path.exists(latest_gen):
                    break

            current_checkpoint = latest_idx
            disc_epoch = load(latest_disc, disc, disc_optimizer,
                              disc_scheduler)
            gen_epoch = load(latest_gen, gen, gen_optimizer, gen_scheduler)
            assert disc_epoch == gen_epoch, \
                'Checkpoint contents are mismatched!'
            print(f"Loaded checkpoint {current_checkpoint}")
        except Exception as e:
            # Best effort: report the problem and continue without aborting.
            # NOTE(review): ``current_checkpoint`` may already have been
            # updated if the failure happened inside ``load`` — confirm that
            # is acceptable.
            print(e)
            print("Unable to load from checkpoint.")

    # Balance coefficient weighting the fake-sample term in the
    # discriminator loss. It always starts at 0 here, including after a
    # restore: it is not passed to ``save`` below.
    k = 0

    # Uniform from -1 to 1
    # Fixed latent batch reused for the "Constant/*" TensorBoard images.
    const_sample = get_z_vector((args.batch_size, *latent_shape),
                                mode='uniform',
                                dtype=torch.float,
                                device=device)

    n_gen_param = sum([x.numel() for x in gen.parameters() if x.requires_grad])
    n_disc_param = sum(
        [x.numel() for x in disc.parameters() if x.requires_grad])
    print(f"{n_gen_param + n_disc_param} Trainable Parameters")

    # NOTE(review): the epoch loop starts AT ``current_checkpoint``, so after
    # a restore the epoch whose checkpoint was just loaded is run again; and
    # with ``args.epochs == 1`` on a fresh run this guard (0 < 0) skips
    # training entirely. Confirm both are intended.
    if current_checkpoint < args.epochs - 1:
        for e in trange(current_checkpoint,
                        args.epochs,
                        initial=current_checkpoint,
                        desc='Epoch',
                        leave=True,
                        disable=args.disable_tqdm):
            for i, img_batch in tqdm(enumerate(dataloader),
                                     total=len(dataloader),
                                     leave=False,
                                     disable=args.disable_tqdm):
                disc_optimizer.zero_grad()
                gen_optimizer.zero_grad()

                img_batch = img_batch.to(device)

                # Uniform from -1 to 1
                # Separate latent batches for the discriminator and the
                # generator updates.
                d_latent_sample = get_z_vector(
                    (args.batch_size, *latent_shape),
                    mode='uniform',
                    dtype=torch.float,
                    device=device)

                g_latent_sample = get_z_vector(
                    (args.batch_size, *latent_shape),
                    mode='uniform',
                    dtype=torch.float,
                    device=device)

                batch_ac_loss = ac_loss(img_batch, disc)
                # The generator output is detached here so the discriminator
                # loss does not backpropagate into the generator.
                d_fake_ac_loss = ac_loss(
                    gen.forward(d_latent_sample, x2=None,
                                n_cuts=args.n_cuts).detach(), disc)
                g_fake_ac_loss = ac_loss(
                    gen.forward(g_latent_sample, x2=None, n_cuts=args.n_cuts),
                    disc)

                # Closures handed to ``optimizer.step``: each one runs the
                # backward pass for its own loss when the optimizer calls it.
                def d_loss():
                    loss = batch_ac_loss - k * d_fake_ac_loss
                    loss.backward()
                    return loss

                def g_loss():
                    loss = g_fake_ac_loss
                    loss.backward()
                    return loss

                # NOTE(review): ``g_fake_ac_loss`` was computed through
                # ``disc`` before ``disc_optimizer.step`` updates the
                # discriminator parameters; confirm the later backward in
                # ``g_loss`` behaves as intended on the PyTorch version in
                # use.
                disc_optimizer.step(d_loss)
                gen_optimizer.step(g_loss)
                disc_scheduler.step()
                gen_scheduler.step()

                # Proportional update of the balance coefficient from the
                # gap between gamma * real loss and the generator's fake loss.
                k = k + args.prop_gain * \
                    (args.gamma * batch_ac_loss.item() - g_fake_ac_loss.item())

                # Logged "Convergence" value: real-image loss re-evaluated
                # after the updates plus the absolute balance gap.
                m = ac_loss(img_batch, disc) + \
                    torch.abs(args.gamma * batch_ac_loss - g_fake_ac_loss)
                writer.add_scalar("Convergence", m, len(dataloader) * e + i)

                if (i % args.log_every == 0):
                    # First image generated from this step's random latents.
                    ex_img = gen.forward(g_latent_sample,
                                         x2=None,
                                         n_cuts=args.n_cuts)[0]
                    writer.add_image("Random/Raw", ex_img,
                                     len(dataloader) * e + i)
                    writer.add_image("Random/Clamp", ex_img.clamp(0, 1),
                                     len(dataloader) * e + i)
                    writer.add_image("Random/Normalize", normalize(ex_img),
                                     len(dataloader) * e + i)
                    # First image generated from the fixed latent batch, for
                    # comparing the same input across training steps.
                    ex_img_const = gen.forward(const_sample,
                                               x2=None,
                                               n_cuts=args.n_cuts)[0]
                    writer.add_image("Constant/Raw", ex_img_const,
                                     len(dataloader) * e + i)
                    writer.add_image("Constant/Clamp",
                                     ex_img_const.clamp(0, 1),
                                     len(dataloader) * e + i)
                    writer.add_image("Constant/Normalize",
                                     normalize(ex_img_const),
                                     len(dataloader) * e + i)

            # One generator and one discriminator checkpoint per finished
            # epoch, named by epoch index.
            save(os.path.join(checkpoint_path, f"gen_ckpt.{e}.pt"), e, gen,
                 gen_optimizer, gen_scheduler)
            save(os.path.join(checkpoint_path, f"disc_ckpt.{e}.pt"), e, disc,
                 disc_optimizer, disc_scheduler)
예제 #6
0
    lasso_est = Lasso(alpha=gamma)
    lasso_est.fit(A_val.T, y_val.reshape(-1))
    x_hat = lasso_est.coef_
    x_hat = np.reshape(x_hat, [-1])
    return x_hat


if __name__ == '__main__':
    DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    a = argparse.ArgumentParser()
    a.add_argument('--img_dir', required=True)
    a.add_argument('--disable_tqdm', default=False)
    args = a.parse_args()

    gen = Generator128(64)
    gen = load_trained_net(
        gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
              '.n_cuts=0/gen_ckpt.49.pt'))
    gen = gen.eval().to(DEVICE)

    n_cuts = 3

    img_size = 128
    img_shape = (3, img_size, img_size)

    forward_model = GaussianCompressiveSensing(n_measure=2500,
                                               img_shape=img_shape)
    # forward_model = NoOp()

    for img_name in tqdm(os.listdir(args.img_dir),
예제 #7
0
def mgan_images(args):
    """Run ``mgan_recover`` over every image in ``args.img_dir`` and save results.

    For each image, every combination of ``n_cuts``, forward model (with each
    of its argument sets) and ``(z_init_mode, limit)`` pair configured for
    ``args.model`` is recovered, unless ``recovered.pt`` already exists in
    the results folder and ``args.overwrite`` is false. The recovered tensor,
    the metadata used and the PSNR are written to the results folder; the
    original and (when viewable) distorted images are written once to the
    images folder.

    Args:
        args: Namespace providing ``set_seed``, ``model``, ``img_dir``,
            ``disable_tqdm``, ``overwrite``, ``run_name`` and ``run_dir``.

    Raises:
        NotImplementedError: If ``args.model`` is not a supported model name.
    """
    if args.set_seed:
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)

    if args.model in ['mgan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['mgan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        gen = gen.decoder
        img_size = 128
    elif args.model in ['mgan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'))
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        raise NotImplementedError()

    img_shape = (3, img_size, img_size)

    # Work on a shallow copy: the sweep lists are removed and per-run keys
    # are added below, and doing that on the shared ``recovery_settings``
    # entry would make a second call for the same model fail with KeyError.
    metadata = dict(recovery_settings[args.model])
    n_cuts_list = metadata.pop('n_cuts_list')

    z_init_mode_list = metadata.pop('z_init_mode')
    limit_list = metadata.pop('limit')
    assert len(z_init_mode_list) == len(limit_list)

    forwards = forward_models[args.model]

    data_split = Path(args.img_dir).name
    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     img_size).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)

        for n_cuts in tqdm(n_cuts_list,
                           desc='N_cuts',
                           leave=False,
                           disable=args.disable_tqdm):
            metadata['n_cuts'] = n_cuts
            for i, (f, f_args_list) in enumerate(
                    tqdm(forwards.items(),
                         desc='Forwards',
                         leave=False,
                         disable=args.disable_tqdm)):
                for f_args in tqdm(f_args_list,
                                   desc=f'{f} Args',
                                   leave=False,
                                   disable=args.disable_tqdm):

                    f_args['img_shape'] = img_shape
                    forward_model = get_forward_model(f, **f_args)

                    # Honour --disable_tqdm on the innermost bar as well,
                    # like every other progress bar in this function.
                    for z_init_mode, limit in zip(
                            tqdm(z_init_mode_list,
                                 desc='z_init_mode',
                                 leave=False,
                                 disable=args.disable_tqdm), limit_list):
                        metadata['z_init_mode'] = z_init_mode
                        metadata['limit'] = limit

                        # Before doing recovery, check if results already exist
                        # and possibly skip
                        recovered_name = 'recovered.pt'
                        results_folder = get_results_folder(
                            image_name=img_basename,
                            model=args.model,
                            n_cuts=n_cuts,
                            split=data_split,
                            forward_model=forward_model,
                            recovery_params=dict_to_str(metadata),
                            base_dir=BASE_DIR)

                        os.makedirs(results_folder, exist_ok=True)

                        recovered_path = results_folder / recovered_name
                        if os.path.exists(
                                recovered_path) and not args.overwrite:
                            print(
                                f'{recovered_path} already exists, skipping...'
                            )
                            continue

                        if args.run_name is not None:
                            current_run_name = (
                                f'{img_basename}.{forward_model}'
                                f'.{dict_to_str(metadata)}'
                                f'.{args.run_name}')
                        else:
                            current_run_name = None

                        recovered_img, distorted_img, _ = mgan_recover(
                            orig_img, gen, n_cuts, forward_model,
                            metadata['optimizer'], z_init_mode, limit,
                            metadata['z_lr'], metadata['n_steps'],
                            metadata['z_number'], metadata['restarts'],
                            args.run_dir, current_run_name, args.disable_tqdm)

                        # Make images folder
                        img_folder = get_images_folder(split=data_split,
                                                       image_name=img_basename,
                                                       img_size=img_size,
                                                       base_dir=BASE_DIR)
                        os.makedirs(img_folder, exist_ok=True)

                        # Save original image if needed
                        original_img_path = img_folder / 'original.pt'
                        if not os.path.exists(original_img_path):
                            torch.save(orig_img, original_img_path)

                        # Save distorted image if needed
                        if forward_model.viewable:
                            distorted_img_path = img_folder / f'{forward_model}.pt'
                            if not os.path.exists(distorted_img_path):
                                torch.save(distorted_img, distorted_img_path)

                        # Save recovered image and metadata. The pickle files
                        # are opened with ``with`` so each handle is flushed
                        # and closed right away instead of leaking one per
                        # iteration of this deeply nested loop.
                        torch.save(recovered_img, recovered_path)
                        with open(results_folder / 'metadata.pkl',
                                  'wb') as fh:
                            pickle.dump(metadata, fh)
                        p = psnr(recovered_img, orig_img)
                        with open(results_folder / 'psnr.pkl', 'wb') as fh:
                            pickle.dump(p, fh)