# Example #1
# 0
def iagan_recover(x,
                  gen,
                  forward_model,
                  optimizer_type,
                  mode='clamped_normal',
                  limit=1,
                  z_lr1=1e-4,
                  z_lr2=1e-5,
                  model_lr=1e-5,
                  z_steps1=1600,
                  z_steps2=3000,
                  restarts=1,
                  run_dir=None,
                  run_name=None,
                  disable_tqdm=False,
                  **kwargs):
    """Run IAGAN recovery ``restarts`` times and keep the best attempt.

    Every attempt calls ``_iagan_recover`` with identical settings. Only the
    tensorboard run name differs: it is ``f'{run_name}_{i}'`` when
    ``run_name`` is given, otherwise ``None``.

    Attempts are ranked by ``psnr_from_mse`` of the third element of the
    returned tuple (the clamped measurement-space MSE).

    Returns:
        The result tuple of the best attempt, or ``None`` when no attempt
        scored above ``-inf`` (e.g. ``restarts == 0``).
    """
    shared = dict(x=x,
                  gen=gen,
                  forward_model=forward_model,
                  optimizer_type=optimizer_type,
                  mode=mode,
                  limit=limit,
                  z_lr1=z_lr1,
                  z_lr2=z_lr2,
                  model_lr=model_lr,
                  z_steps1=z_steps1,
                  z_steps2=z_steps2,
                  run_dir=run_dir,
                  disable_tqdm=disable_tqdm)

    top_psnr, top_result = -float("inf"), None
    for attempt in trange(restarts, desc='Restarts', leave=False,
                          disable=disable_tqdm):
        attempt_name = None if run_name is None else f'{run_name}_{attempt}'
        result = _iagan_recover(run_name=attempt_name, **shared, **kwargs)
        score = psnr_from_mse(result[2])
        if score > top_psnr:
            top_psnr, top_result = score, result

    return top_result
# Example #2
# 0
def recover(x,
            gen,
            optimizer_type,
            n_cuts,
            forward_model,
            mode='clamped_normal',
            limit=1,
            z_lr=0.5,
            n_steps=2000,
            restarts=1,
            run_dir=None,
            run_name=None,
            disable_tqdm=False,
            return_z1_z2=False,
            **kwargs):
    """Run latent-space recovery ``restarts`` times and keep the best run.

    Each run delegates to ``_recover`` with the same settings. When
    ``run_name`` is set, run ``i`` logs under ``f'{run_name}_{i}'``.

    The winner is the run whose third returned element (the clamped
    measurement MSE) yields the largest ``psnr_from_mse``.

    Returns:
        The winning result tuple. It includes the ``{'z1', 'z2'}`` dict when
        ``return_z1_z2`` is set. Returns ``None`` if no run scored above
        ``-inf`` (e.g. ``restarts == 0``).
    """
    common = dict(x=x,
                  gen=gen,
                  optimizer_type=optimizer_type,
                  n_cuts=n_cuts,
                  forward_model=forward_model,
                  mode=mode,
                  limit=limit,
                  z_lr=z_lr,
                  n_steps=n_steps,
                  run_dir=run_dir,
                  disable_tqdm=disable_tqdm,
                  return_z1_z2=return_z1_z2)

    leader_psnr = -float("inf")
    leader = None
    runs = trange(restarts, desc='Restarts', leave=False, disable=disable_tqdm)
    for idx in runs:
        if run_name is None:
            tag = None
        else:
            tag = f'{run_name}_{idx}'
        outcome = _recover(run_name=tag, **common, **kwargs)
        outcome_psnr = psnr_from_mse(outcome[2])
        if outcome_psnr > leader_psnr:
            leader_psnr = outcome_psnr
            leader = outcome

    return leader
def deep_decoder_recover(
        x,
        forward_model,
        optimizer='lbfgs',
        num_filters=64,
        depth=6,  # TODO
        lr=1,
        img_size=64,
        steps=50,
        restarts=1,
        run_dir=None,
        run_name=None,
        disable_tqdm=False,
        **kwargs):
    """Fit a DeepDecoder ``restarts`` times and keep the best fit.

    Each restart calls ``_deep_decoder_recover`` with the same settings.
    When ``run_name`` is given, restart ``i`` logs under ``f'{run_name}_{i}'``.

    The fit whose third returned element (the clamped measurement MSE) maps
    to the highest ``psnr_from_mse`` is kept.

    Returns:
        The best result tuple, or ``None`` if no fit scored above ``-inf``
        (e.g. ``restarts == 0``).
    """
    fit_kwargs = dict(x=x,
                      forward_model=forward_model,
                      optimizer=optimizer,
                      num_filters=num_filters,
                      depth=depth,
                      lr=lr,
                      img_size=img_size,
                      steps=steps,
                      run_dir=run_dir,
                      disable_tqdm=disable_tqdm)

    best_so_far = -float("inf")
    winner = None
    for restart_idx in trange(restarts,
                              desc='Restarts',
                              leave=False,
                              disable=disable_tqdm):
        name_for_restart = (f'{run_name}_{restart_idx}'
                            if run_name is not None else None)
        fit = _deep_decoder_recover(run_name=name_for_restart,
                                    **fit_kwargs,
                                    **kwargs)
        fit_psnr = psnr_from_mse(fit[2])
        if fit_psnr > best_so_far:
            best_so_far, winner = fit_psnr, fit

    return winner
# Example #4
# 0
def _recover(x,
             gen,
             optimizer_type,
             n_cuts,
             forward_model,
             mode='clamped_normal',
             limit=1,
             z_lr=0.5,
             n_steps=2000,
             run_dir=None,
             run_name=None,
             disable_tqdm=False,
             return_z1_z2=False,
             **kwargs):
    """Recover an image by optimizing the generator inputs at ``n_cuts``.

    Args:
        x - input image, torch tensor (C x H x W)
        gen - generator, already loaded with checkpoint weights
        optimizer_type - 'adamw' or 'lbfgs'; anything else raises
            NotImplementedError
        n_cuts - index into ``gen.input_shapes`` selecting the (z1, z2)
            shapes that are optimized
        forward_model - corrupts the image
        mode, limit - forwarded to ``get_z_vector`` for initialization
        z_lr - learning rate of the latent optimizer
        n_steps - number of optimization steps during recovery
        run_dir - subdirectory of 'recovery_tensorboard_logs' for logs
        run_name - use None for no logging
        disable_tqdm - hide the progress bar
        return_z1_z2 - also return a dict with the optimized 'z1' / 'z2'
        **kwargs - forwarded to ``gen.forward``

    With ``mode='lasso_inverse'`` and a GaussianCompressiveSensing forward
    model, the latents are warm-started. A ``recover_dct`` (lasso) estimate
    of the image is computed first, ``recover`` is run on it, and its
    optimized latents are reused here.

    Returns:
        (x_hat, forward_model(x)[0], best_train_mse). A fourth element,
        ``{'z1': z1, 'z2': z2}``, is added when ``return_z1_z2`` is set.
        ``best_train_mse`` is the measurement-space MSE of the clamped
        reconstruction after the last step.
    """

    # Keep batch_size = 1
    batch_size = 1

    z1_dim, z2_dim = gen.input_shapes[n_cuts]

    if (isinstance(forward_model, GaussianCompressiveSensing)):
        # Additive measurement noise, scaled by the ratio of this image's
        # pixel count to a 64 x 64 x 3 reference size.
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(batch_size,
                            forward_model.n_measure,
                            device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    if mode == 'lasso_inverse' and isinstance(forward_model,
                                              GaussianCompressiveSensing):
        # Warm start: lasso estimate -> latent recovery on that estimate
        # (clamped_normal init, single restart, no logging).
        lasso_x_hat = recover_dct(x.cpu().numpy().transpose([1, 2, 0]),
                                  forward_model.n_measure,
                                  0.01,
                                  128,
                                  A=forward_model.A.cpu().numpy(),
                                  noise=noise.cpu().numpy())

        _, _, _, z1_z2_dict = recover(torch.tensor(
            lasso_x_hat.transpose([2, 0, 1]), dtype=torch.float).to(DEVICE),
                                      gen,
                                      optimizer_type=optimizer_type,
                                      n_cuts=n_cuts,
                                      forward_model=forward_model,
                                      mode='clamped_normal',
                                      limit=limit,
                                      z_lr=z_lr,
                                      n_steps=n_steps,
                                      restarts=1,
                                      return_z1_z2=True)
        z1 = torch.nn.Parameter(z1_z2_dict['z1'])
        params = [z1]
        if len(z2_dim) > 0:
            z2 = torch.nn.Parameter(z1_z2_dict['z2'])
            params.append(z2)
        else:
            z2 = None

    else:
        z1 = torch.nn.Parameter(
            get_z_vector((batch_size, *z1_dim),
                         mode=mode,
                         limit=limit,
                         device=x.device))
        params = [z1]
        if len(z2_dim) > 0:
            z2 = torch.nn.Parameter(
                get_z_vector((batch_size, *z2_dim),
                             mode=mode,
                             limit=limit,
                             device=x.device))
            params.append(z2)
        else:
            z2 = None

    if optimizer_type == 'adamw':
        optimizer_z = torch.optim.AdamW(params,
                                        lr=z_lr,
                                        betas=(0.5, 0.999),
                                        weight_decay=0)
        scheduler_z = None
        save_img_every_n = 50
    elif optimizer_type == 'lbfgs':
        optimizer_z = torch.optim.LBFGS(params, lr=z_lr)
        scheduler_z = None
        save_img_every_n = 2
    else:
        raise NotImplementedError()

    writer = None
    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)

    try:
        # Save original and distorted image
        if writer is not None:
            writer.add_image("Original/Clamp", x.clamp(0, 1))
            if forward_model.viewable:
                writer.add_image(
                    "Distorted/Clamp",
                    forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

        # Recover image under forward model
        x = x.expand(batch_size, *x.shape)

        y_observed = forward_model(x)
        if (isinstance(forward_model, GaussianCompressiveSensing)):
            y_observed += noise

        for j in trange(n_steps,
                        leave=False,
                        desc='Recovery',
                        disable=disable_tqdm):

            # LBFGS may evaluate the closure several times per step, so the
            # forward pass and backward pass live here.
            def closure():
                optimizer_z.zero_grad()
                x_hats = gen.forward(z1, z2, n_cuts=n_cuts, **kwargs)
                if gen.rescale:
                    x_hats = (x_hats + 1) / 2
                train_mses = F.mse_loss(forward_model(x_hats),
                                        y_observed,
                                        reduction='none')
                train_mses = train_mses.view(batch_size, -1).mean(1)

                train_mse = train_mses.sum()
                train_mse.backward()
                return train_mse

            # Step first, then identify the current "best" and "worst"
            optimizer_z.step(closure)
            with torch.no_grad():
                x_hats = gen.forward(z1, z2, n_cuts=n_cuts, **kwargs)
                if gen.rescale:
                    x_hats = (x_hats + 1) / 2
                train_mses = F.mse_loss(forward_model(x_hats),
                                        y_observed,
                                        reduction='none')
                train_mses = train_mses.view(batch_size, -1).mean(1)
                train_mse = train_mses.sum()

            train_mses_clamped = F.mse_loss(
                forward_model(x_hats.detach().clamp(0, 1)),
                y_observed,
                reduction='none').view(batch_size, -1).mean(1)
            orig_mses_clamped = F.mse_loss(
                x_hats.detach().clamp(0, 1),
                x,
                reduction='none').view(batch_size, -1).mean(1)

            # batch_size = 1, so best and worst are meaningless.
            # Restarts is handled in outer function
            best_train_mse, best_idx = train_mses_clamped.min(0)
            worst_train_mse, worst_idx = train_mses_clamped.max(0)
            best_orig_mse = orig_mses_clamped[best_idx]
            worst_orig_mse = orig_mses_clamped[worst_idx]

            if writer is not None and j == 0:
                writer.add_image('Start', x_hats[best_idx].clamp(0, 1))

            if writer is not None:
                writer.add_scalar('TRAIN_MSE/best', best_train_mse, j + 1)
                writer.add_scalar('TRAIN_MSE/worst', worst_train_mse, j + 1)
                writer.add_scalar('TRAIN_MSE/sum', train_mse, j + 1)
                writer.add_scalar('ORIG_MSE/best', best_orig_mse, j + 1)
                writer.add_scalar('ORIG_MSE/worst', worst_orig_mse, j + 1)
                writer.add_scalar('ORIG_PSNR/best',
                                  psnr_from_mse(best_orig_mse), j + 1)
                writer.add_scalar('ORIG_PSNR/worst',
                                  psnr_from_mse(worst_orig_mse), j + 1)

                if j % save_img_every_n == 0:
                    writer.add_image('Recovered/Best',
                                     x_hats[best_idx].clamp(0, 1), j + 1)

            if scheduler_z is not None:
                scheduler_z.step()

        if writer is not None:
            writer.add_image('Final', x_hats[best_idx].clamp(0, 1))

        if return_z1_z2:
            return x_hats[best_idx], forward_model(x)[0], best_train_mse, {
                'z1': z1,
                'z2': z2
            }
        return x_hats[best_idx], forward_model(x)[0], best_train_mse
    finally:
        # Close the writer so pending events are flushed and its event file
        # is released, even if recovery raises.
        if writer is not None:
            writer.close()
def _deep_decoder_recover(
    x,
    forward_model,
    optimizer,
    num_filters,
    depth,
    lr,
    img_size,
    steps,
    run_dir,
    run_name,
    disable_tqdm,
    **kwargs,
):
    """Recover an image by fitting a freshly initialized DeepDecoder.

    A fixed random input ``z`` is fed to a new ``DeepDecoder``. Only the
    decoder's weights are optimized, so that ``forward_model(model(z))``
    matches the observed measurements.

    Args:
        x - input image, torch tensor (C x H x W)
        forward_model - corrupts the image
        optimizer - 'adam' or 'lbfgs'; anything else raises
            NotImplementedError
        num_filters, depth, img_size - DeepDecoder configuration
        lr - learning rate
        steps - number of optimizer steps
        run_dir - subdirectory of 'recovery_tensorboard_logs' for logs
        run_name - use None for no logging
        disable_tqdm - hide the progress bar
        **kwargs - accepted for interface compatibility; unused

    Returns:
        (x_hat, forward_model(x), train_mse_clamped), the first two
        squeezed. ``train_mse_clamped`` is the measurement MSE of the
        clamped reconstruction after the last step.
    """
    # Keep batch_size = 1
    batch_size = 1

    if (isinstance(forward_model, GaussianCompressiveSensing)):
        # Additive measurement noise, scaled by the ratio of this image's
        # pixel count to a 64 x 64 x 3 reference size.
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(batch_size,
                            forward_model.n_measure,
                            device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    # z is a fixed latent vector (never optimized)
    # NOTE(review): this is a difference of exponents, not img_size / 2**k;
    # confirm it matches DeepDecoder's expected input size for
    # img_size / depth combinations other than the defaults.
    start_imsize = int(np.log2(img_size)) - depth + 1
    z = torch.randn(batch_size,
                    num_filters,
                    start_imsize,
                    start_imsize,
                    device=x.device)

    # make a fresh DD model for every run
    model = DeepDecoder(num_filters=num_filters,
                        img_size=img_size,
                        depth=depth).to(x.device)

    # Bind the optimizer object to its own name instead of overwriting the
    # `optimizer` string argument.
    if optimizer == 'adam':
        opt = torch.optim.Adam(model.parameters(), lr=lr)
        save_img_every_n = 50
    elif optimizer == 'lbfgs':
        opt = torch.optim.LBFGS(model.parameters(), lr=lr)
        save_img_every_n = 2
    else:
        raise NotImplementedError()

    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)
    else:
        writer = None

    try:
        # Save original and distorted image
        if writer is not None:
            writer.add_image("Original/Clamp", x.clamp(0, 1))
            if forward_model.viewable:
                writer.add_image(
                    "Distorted/Clamp",
                    forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

        # Measurements; Gaussian noise is added only for
        # GaussianCompressiveSensing.
        x = x.expand(batch_size, *x.shape)
        y_observed = forward_model(x)
        if (isinstance(forward_model, GaussianCompressiveSensing)):
            y_observed += noise

        # LBFGS may evaluate the closure several times per step.
        def closure():
            opt.zero_grad()
            x_hat = model.forward(z)
            loss = F.mse_loss(forward_model(x_hat), y_observed)
            loss.backward()
            return loss

        for j in trange(steps, desc='Fit', leave=False, disable=disable_tqdm):
            opt.step(closure)
            with torch.no_grad():
                x_hat = model.forward(z)

            train_mse_clamped = F.mse_loss(
                forward_model(x_hat.detach().clamp(0, 1)), y_observed)
            if writer is not None:
                writer.add_scalar('TRAIN_MSE', train_mse_clamped, j + 1)
                writer.add_scalar('TRAIN_PSNR',
                                  psnr_from_mse(train_mse_clamped), j + 1)

                orig_mse_clamped = F.mse_loss(x_hat.detach().clamp(0, 1), x)
                writer.add_scalar('ORIG_MSE', orig_mse_clamped, j + 1)
                writer.add_scalar('ORIG_PSNR',
                                  psnr_from_mse(orig_mse_clamped), j + 1)
                if j % save_img_every_n == 0:
                    writer.add_image('Recovered',
                                     x_hat.squeeze().clamp(0, 1), j + 1)

        if writer is not None:
            writer.add_image('Final', x_hat.squeeze().clamp(0, 1))

        return x_hat.squeeze(), forward_model(x).squeeze(), train_mse_clamped
    finally:
        # Flush pending events and release the event file, even on error.
        if writer is not None:
            writer.close()
# Example #6
# 0
def _iagan_recover(
        x,
        gen,
        forward_model,
        optimizer_type='adam',
        mode='clamped_normal',
        limit=1,
        z_lr1=1e-4,
        z_lr2=1e-5,
        model_lr=1e-5,
        z_steps1=1600,
        z_steps2=3000,
        run_dir=None,  # IAGAN
        run_name=None,  # datetime or config
        disable_tqdm=False,
        **kwargs):
    """Two-stage image-adaptive recovery of the latent code and generator.

    Stage 1 optimizes only the latent ``z1`` (Adam, lr ``z_lr1``) for
    ``z_steps1`` steps. Stage 2 jointly optimizes ``z1`` (Adam, lr
    ``z_lr2``) and all generator parameters (Adam, lr ``model_lr``) for
    ``z_steps2`` steps.

    Stage 2 modifies ``gen``'s weights in place. The weights are snapshotted
    before recovery and restored when this function returns or raises. As a
    result, the caller's generator, and every restart issued by
    ``iagan_recover``, starts from the same weights.

    Args:
        x - input image, torch tensor (C x H x W)
        gen - generator, already loaded with checkpoint weights
        forward_model - corrupts the image
        optimizer_type - only 'adam' is supported; anything else raises
            NotImplementedError
        mode, limit - forwarded to ``get_z_vector`` for initialization
        z_lr1, z_lr2 - latent learning rates for stage 1 and stage 2
        model_lr - generator learning rate in stage 2
        z_steps1, z_steps2 - number of steps in stage 1 and stage 2
        run_dir - subdirectory of 'recovery_tensorboard_logs' for logs
        run_name - use None for no logging
        disable_tqdm - hide the progress bars
        **kwargs - forwarded to ``gen.forward``

    Returns:
        (x_hat, forward_model(x), train_mse_clamped), the first two
        squeezed. ``x_hat`` and ``train_mse_clamped`` come from the last
        stage-2 iteration, computed before that iteration's parameter update.
    """
    import copy

    # Keep batch_size = 1
    batch_size = 1

    z1_dim, z2_dim = gen.input_shapes[0]  # n_cuts = 0

    if (isinstance(forward_model, GaussianCompressiveSensing)):
        # Additive measurement noise, scaled by the ratio of this image's
        # pixel count to a 64 x 64 x 3 reference size.
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(batch_size,
                            forward_model.n_measure,
                            device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    # z1 is the actual latent code.
    # z2 is the additional input for n_cuts logic. It is created when the
    # generator expects it, but it is never optimized here.
    z1 = torch.nn.Parameter(
        get_z_vector((batch_size, *z1_dim),
                     mode=mode,
                     limit=limit,
                     device=x.device))
    if len(z2_dim) > 0:
        z2 = torch.nn.Parameter(
            get_z_vector((batch_size, *z2_dim),
                         mode=mode,
                         limit=limit,
                         device=x.device))
    else:
        z2 = None

    if optimizer_type == 'adam':
        optimizer_z = torch.optim.Adam([z1], lr=z_lr1)
        optimizer_model = torch.optim.Adam(gen.parameters(), lr=model_lr)
    else:
        raise NotImplementedError()

    # Stage 2 updates gen's weights in place; keep a deep copy so they can be
    # restored once this recovery finishes.
    gen_state = copy.deepcopy(gen.state_dict())

    writer = None
    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)

    try:
        # Save original and distorted image
        if writer is not None:
            writer.add_image("Original/Clamp", x.clamp(0, 1))
            if forward_model.viewable:
                writer.add_image(
                    "Distorted/Clamp",
                    forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

        # Measurements; Gaussian noise is added only for
        # GaussianCompressiveSensing.
        x = x.expand(batch_size, *x.shape)
        y_observed = forward_model(x)
        if (isinstance(forward_model, GaussianCompressiveSensing)):
            y_observed += noise

        # Stage 1: optimize latent code only
        save_img_every_n = 50
        for j in trange(z_steps1, desc='Stage1', leave=False,
                        disable=disable_tqdm):
            optimizer_z.zero_grad()
            x_hat = gen.forward(z1, z2, n_cuts=0, **kwargs)
            if gen.rescale:
                x_hat = (x_hat + 1) / 2
            train_mse = F.mse_loss(forward_model(x_hat), y_observed)
            train_mse.backward()
            optimizer_z.step()

            train_mse_clamped = F.mse_loss(
                forward_model(x_hat.detach().clamp(0, 1)), y_observed)

            orig_mse_clamped = F.mse_loss(x_hat.detach().clamp(0, 1), x)

            if writer is not None and j == 0:
                writer.add_image('Stage1/Start', x_hat.squeeze().clamp(0, 1))

            if writer is not None:
                writer.add_scalar('Stage1/TRAIN_MSE', train_mse_clamped,
                                  j + 1)
                writer.add_scalar('Stage1/ORIG_MSE', orig_mse_clamped, j + 1)
                writer.add_scalar('Stage1/ORIG_PSNR',
                                  psnr_from_mse(orig_mse_clamped), j + 1)

                if j % save_img_every_n == 0:
                    writer.add_image('Stage1/Recovered',
                                     x_hat.squeeze().clamp(0, 1), j + 1)

        if writer is not None:
            writer.add_image('Stage1_Final', x_hat.squeeze().clamp(0, 1))

        # Stage 2: optimize latent code and model
        save_img_every_n = 20
        # Fresh latent optimizer with the stage-2 learning rate.
        optimizer_z = torch.optim.Adam([z1], lr=z_lr2)
        for j in trange(z_steps2, desc='Stage2', leave=False,
                        disable=disable_tqdm):
            optimizer_z.zero_grad()
            optimizer_model.zero_grad()
            x_hat = gen.forward(z1, z2, n_cuts=0, **kwargs)
            if gen.rescale:
                x_hat = (x_hat + 1) / 2
            train_mse = F.mse_loss(forward_model(x_hat), y_observed)
            train_mse.backward()
            optimizer_z.step()
            optimizer_model.step()

            train_mse_clamped = F.mse_loss(
                forward_model(x_hat.detach().clamp(0, 1)), y_observed)

            orig_mse_clamped = F.mse_loss(x_hat.detach().clamp(0, 1), x)

            if writer is not None and j == 0:
                writer.add_image('Stage2/Start', x_hat.squeeze().clamp(0, 1))

            if writer is not None:
                writer.add_scalar('Stage2/TRAIN_MSE', train_mse_clamped,
                                  j + 1)
                writer.add_scalar('Stage2/ORIG_MSE', orig_mse_clamped, j + 1)
                writer.add_scalar('Stage2/ORIG_PSNR',
                                  psnr_from_mse(orig_mse_clamped), j + 1)

                if j % save_img_every_n == 0:
                    writer.add_image('Stage2/Recovered',
                                     x_hat.squeeze().clamp(0, 1), j + 1)

        if writer is not None:
            writer.add_image('Stage2_Final', x_hat.squeeze().clamp(0, 1))

        return x_hat.squeeze(), forward_model(x).squeeze(), train_mse_clamped
    finally:
        # Undo the stage-2 fine-tuning so gen is left as the caller passed it.
        gen.load_state_dict(gen_state)
        # Flush pending events and release the event file, even on error.
        if writer is not None:
            writer.close()
# Example #7
# 0
def _mgan_recover(x,
                  gen,
                  n_cuts,
                  forward_model,
                  optimizer_type='sgd',
                  mode='zero',
                  limit=1,
                  z_lr=1,
                  n_steps=2000,
                  z_number=20,
                  run_dir=None,
                  run_name=None,
                  disable_tqdm=False,
                  **kwargs):
    """Recover an image from several latents combined at layer ``n_cuts``.

    ``z_number`` latent vectors ``z1`` are passed through the generator up
    to layer ``n_cuts``. The resulting feature maps are summed over the
    latent axis, weighted by ``alpha`` (one weight per latent and per
    feature channel). The remaining layers turn the single combined map,
    plus an optional extra input ``z2``, into the image. ``z1``, ``alpha``
    and ``z2`` are optimized jointly.

    Args:
        x - input image, torch tensor (C x H x W)
        gen - generator, already loaded with checkpoint weights
        n_cuts - the intermediate layer to combine z vectors
        forward_model - corrupts the image
        optimizer_type - 'sgd' or 'adam'; anything else raises
            NotImplementedError
        mode, limit - forwarded to ``get_z_vector`` for initialization
        z_lr - learning rate
        n_steps - number of optimization steps during recovery
        z_number - number of latent vectors combined at ``n_cuts``
        run_dir - subdirectory of 'recovery_tensorboard_logs' for logs
        run_name - use None for no logging
        disable_tqdm - hide the progress bar
        **kwargs - forwarded to ``gen.forward``

    Returns:
        (x_hat, forward_model(x)[0], train_mse_clamped). ``x_hat`` and
        ``train_mse_clamped`` come from the last iteration, computed before
        that iteration's parameter update.
    """

    z1_dim, _ = gen.input_shapes[0]
    _, z2_dim = gen.input_shapes[n_cuts]

    if (isinstance(forward_model, GaussianCompressiveSensing)):
        # Additive measurement noise, scaled by the ratio of this image's
        # pixel count to a 64 x 64 x 3 reference size.
        n_pixel_bora = 64 * 64 * 3
        n_pixel = np.prod(x.shape)
        noise = torch.randn(1, forward_model.n_measure, device=x.device)
        noise *= 0.1 * torch.sqrt(
            torch.tensor(n_pixel / forward_model.n_measure / n_pixel_bora))

    z1 = torch.nn.Parameter(
        get_z_vector((z_number, *z1_dim),
                     mode=mode,
                     limit=limit,
                     device=x.device))
    # Combination weights: (z_number, channels of the n_cuts input); they
    # broadcast over the spatial dims of F_l below.
    alpha = torch.nn.Parameter(
        get_z_vector((z_number, gen.input_shapes[n_cuts][0][0]),
                     mode=mode,
                     limit=limit,
                     device=x.device))
    params = [z1, alpha]
    if len(z2_dim) > 0:
        z2 = torch.nn.Parameter(
            get_z_vector((1, *z2_dim),
                         mode=mode,
                         limit=limit,
                         device=x.device))
        params.append(z2)
    else:
        z2 = None

    if optimizer_type == 'sgd':
        optimizer_z = torch.optim.SGD(params, lr=z_lr)
        scheduler_z = None
        save_img_every_n = 50
    elif optimizer_type == 'adam':
        optimizer_z = torch.optim.Adam(params, lr=z_lr)
        scheduler_z = None
        # scheduler_z = torch.optim.lr_scheduler.CosineAnnealingLR(
        #     optimizer_z, n_steps, 0.05 * z_lr)
        save_img_every_n = 50
    else:
        raise NotImplementedError()

    writer = None
    if run_name is not None:
        logdir = os.path.join('recovery_tensorboard_logs', run_dir, run_name)
        if os.path.exists(logdir):
            print("Overwriting pre-existing logs!")
            shutil.rmtree(logdir)
        writer = SummaryWriter(logdir)

    try:
        # Save original and distorted image
        if writer is not None:
            writer.add_image("Original/Clamp", x.clamp(0, 1))
            if forward_model.viewable:
                writer.add_image(
                    "Distorted/Clamp",
                    forward_model(x.unsqueeze(0).clamp(0, 1)).squeeze(0))

        # Recover image under forward model
        x = x.expand(1, *x.shape)
        y_observed = forward_model(x)
        if (isinstance(forward_model, GaussianCompressiveSensing)):
            y_observed += noise

        for j in trange(n_steps,
                        leave=False,
                        desc='Recovery',
                        disable=disable_tqdm):

            optimizer_z.zero_grad()
            # Features of every z1 at layer n_cuts, then their alpha-weighted
            # sum over the latent axis (keepdim gives a batch of 1).
            F_l = gen.forward(z1, None, n_cuts=0, end=n_cuts, **kwargs)
            F_l_2 = (F_l * alpha[:, :, None, None]).sum(0, keepdim=True)
            x_hats = gen.forward(F_l_2, z2, n_cuts=n_cuts, end=None, **kwargs)
            if gen.rescale:
                x_hats = (x_hats + 1) / 2
            train_mse = F.mse_loss(forward_model(x_hats), y_observed)
            train_mse.backward()
            optimizer_z.step()

            train_mse_clamped = F.mse_loss(
                forward_model(x_hats.detach().clamp(0, 1)), y_observed)
            orig_mse_clamped = F.mse_loss(x_hats.detach().clamp(0, 1), x)

            if writer is not None and j == 0:
                writer.add_image('Start', x_hats.clamp(0, 1).squeeze(0))

            if writer is not None:
                writer.add_scalar('TRAIN_MSE', train_mse_clamped, j + 1)
                writer.add_scalar('ORIG_MSE', orig_mse_clamped, j + 1)
                writer.add_scalar('ORIG_PSNR',
                                  psnr_from_mse(orig_mse_clamped), j + 1)

                if j % save_img_every_n == 0:
                    writer.add_image('Recovered',
                                     x_hats.clamp(0, 1).squeeze(0), j + 1)

            if scheduler_z is not None:
                scheduler_z.step()

        if writer is not None:
            writer.add_image('Final', x_hats.clamp(0, 1).squeeze(0))

        return x_hats.squeeze(0), forward_model(x)[0], train_mse_clamped
    finally:
        # Flush pending events and release the event file, even on error.
        if writer is not None:
            writer.close()