示例#1
0
def main():
    """Build and save one reconstruction-comparison figure per measurement count.

    Images selected by ``get_image_nums`` for the index window (20, 30) are
    rendered by ``view`` alongside the estimates stored under the
    ``lasso-dct``, ``lasso-wavelet`` and ``dcgan`` result directories, and each
    figure is written to ``./results/`` as a PDF via ``utils.save_plot``.
    """
    hparams = Hparams()
    xs_dict = celebA_input.model_input(hparams)
    images_nums = get_image_nums(20, 30, hparams)
    save_enabled = True

    estimates_root = './estimated/celebA/full-input/gaussian/0.01/'
    # One entry per method shown in the figure, in display order.
    method_suffixes = (
        '/lasso-dct/0.1/{0}.png',
        '/lasso-wavelet/1e-05/{0}.png',
        '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_10/{0}.png',
    )
    pdf_template = './results/celebA_reconstr_{}_orig_lasso-dct_lasso-wavelet_dcgan.pdf'

    measurement_counts = [50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]
    for m in measurement_counts:
        patterns = [estimates_root + str(m) + suffix for suffix in method_suffixes]
        view(xs_dict, patterns, images_nums, hparams)
        utils.save_plot(save_enabled, pdf_template.format(m))
def main(hparams):
    """Reconstruct the input images with every configured estimator and record metrics.

    Images from ``model_input`` are collected into batches of
    ``hparams.batch_size``. Each full batch is measured once
    (``utils.get_measurements`` with matrix ``A`` and a fixed noise batch) and
    passed to every estimator in ``hparams.model_types``. For each image and
    model, the code stores the following, and checkpoints when
    ``hparams.save_images`` is set and the image index hits a multiple of
    ``hparams.checkpoint_iter``:

    - the estimate;
    - the measurement loss;
    - the l2 loss;
    - the LPIPS score;
    - the latent estimate.

    Args:
        hparams: run configuration object. Attributes read here include
            ``model_types``, ``noise_std``, ``batch_size``, ``num_measurements``,
            ``n_input``, ``image_shape``, ``not_lazy``, ``save_images``,
            ``checkpoint_iter``, ``print_stats`` and ``image_matrix``.
    """
    # Perceptual-loss model (net-lin / vgg) passed to utils.get_lpips_score below.
    # NOTE(review): the device is hard-coded, so use_gpu is always True.
    device = 'cuda:0'
    percept = PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    utils.print_hparams(hparams)

    # Get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses, lpips_scores, z_hats = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}

    A = utils.get_A(hparams)
    # Heavy-tailed Student-t noise (2 degrees of freedom), scaled by noise_std.
    # It is drawn ONCE, so every batch is measured with the same noise realisation.
    # NOTE(review): confirm this is intended rather than fresh noise per batch.
    noise_batch = hparams.noise_std * np.random.standard_t(
        2, size=(hparams.batch_size, hparams.num_measurements))

    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # Lazy mode: skip this image if *all* estimators have already saved it.
            save_paths = utils.get_save_paths(hparams, key)
            if all(os.path.isfile(save_path) for save_path in save_paths.values()):
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Flatten each image to one row and stack them into a single batch array.
        x_batch = np.concatenate(
            [x_img.reshape(1, hparams.n_input) for x_img in x_batch_dict.values()])

        # Construct measurements
        y_batch = utils.get_measurements(x_batch, A, noise_batch, hparams)

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            x_hat_batch, z_hat_batch, m_loss_batch = estimator(A, y_batch, hparams)

            # The code assumes row i of the estimator outputs belongs to the
            # i-th key of x_batch_dict (the order x_batch was stacked in).
            for i, batch_key in enumerate(x_batch_dict.keys()):
                x_true = xs_dict[batch_key]
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][batch_key] = x_hat

                # Store measurement loss, l2 loss, LPIPS score and latent estimate.
                # NOTE(review): m_loss_batch is indexed by image key while the other
                # outputs are indexed by batch position i; confirm against the estimator.
                measurement_losses[model_type][batch_key] = m_loss_batch[batch_key]
                l2_losses[model_type][batch_key] = utils.get_l2_loss(x_hat, x_true)
                lpips_scores[model_type][batch_key] = utils.get_lpips_score(
                    percept, x_hat, x_true, hparams.image_shape)
                z_hats[model_type][batch_key] = z_hat_batch[i]

        # `key` is the image that completed this batch.
        print('Processed upto image {0} / {1}'.format(key+1, len(xs_dict)))

        # Checkpointing; x_hats_dict is reset afterwards, so it only holds
        # estimates produced since the last checkpoint.
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
            x_hats_dict = {model_type: {} for model_type in hparams.model_types}
            print('\nProcessed and saved first ', key+1, 'images\n')

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that a trailing partial batch was not processed.
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
示例#3
0
def main(hparams):
    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {'dcgan' : {}}
    x_batch_dict = {}
    for key, x in xs_dict.iteritems():
        if hparams.lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([os.path.isfile(save_path) for save_path in save_paths.values()])
            if is_saved:
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input
        x_batch_list = [x.reshape(1, hparams.n_input) for _, x in x_batch_dict.iteritems()]
        x_batch = np.concatenate(x_batch_list)

        # Construct measurements
        A_outer = utils.get_outer_A(hparams)

        y_batch_outer=np.matmul(x_batch, A_outer)


        x_main_batch = 0.0 * x_batch
        z_opt_batch = np.random.randn(hparams.batch_size, 100)
        for k in range(maxiter):

            x_est_batch=x_main_batch + hparams.outer_learning_rate*(np.matmul((y_batch_outer-np.matmul(x_main_batch,A_outer)),A_outer.T))



            estimator = estimators['dcgan']
            x_hat_batch,z_opt_batch = estimator(x_est_batch,z_opt_batch, hparams)
            x_main_batch=x_hat_batch


        for i, key in enumerate(x_batch_dict.keys()):
            x = xs_dict[key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]

            # Save the estimate
            x_hats_dict['dcgan'][key] = x_hat

            # Compute and store measurement and l2 loss
            measurement_losses['dcgan'][key] = utils.get_measurement_loss(x_hat, A_outer, y)
            l2_losses['dcgan'][key] = utils.get_l2_loss(x_hat, x)
        print 'Processed upto image {0} / {1}'.format(key+1, len(xs_dict))

        # Checkpointing
        if (hparams.save_images) and ((key+1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
            #x_hats_dict = {'dcgan' : {}}
            print '\nProcessed and saved first ', key+1, 'images\n'

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
        print '\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print model_type
            mean_m_loss = np.mean(measurement_losses[model_type].values())
            mean_l2_loss = np.mean(l2_losses[model_type].values())
            print 'mean measurement loss = {0}'.format(mean_m_loss)
            print 'mean l2 loss = {0}'.format(mean_l2_loss)

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processsed
    if len(x_batch_dict) > 0:
        print '\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict))
        print 'Consider rerunning lazily with a smaller batch size.'
示例#4
0
def main():
    """Plot CelebA reconstructions for the MAP / Modified-MAP / Langevin runs and save PDFs.

    For every measurement count, each legend entry's result directory is chosen
    with ``find_best``. If any entry has no sub-directory for that count, the
    count is skipped.

    ``view`` is first tried with ``{0}.png`` image paths. If those are missing
    it is retried with the ``images/{:06d}.png`` layout. If that layout is
    missing too, plotting is skipped for that count but ``utils.save_plot`` is
    still called.
    """
    hparams = Hparams()
    xs_dict = celebA_input.model_input(hparams)
    start, stop = 0, 5
    images_nums = get_image_nums(start, stop, hparams)
    is_save = True

    # (legend, base directory, run-directory glob appended after the measurement count)
    legend_base_regexs = [
        ('MAP',
         f'./estimated*/celebA/full-input/circulant/{hparams.noise_std}/',
         '/glow*map*/*'),
        ('Modified-MAP',
         f'./estimated*/celebA/full-input/circulant/{hparams.noise_std}/',
         '/glow*map*/*'),
        ('Langevin(Ours)',
         f'./estimated*/celebA/full-input/circulant/{hparams.noise_std}/',
         '/glow*langevin/None_None*'),
    ]
    retrieve_list = [['lpips', 'mean'], ['lpips', 'std']]

    for num_measurements in [2500, 5000, 10000, 15000, 20000, 30000, 35000]:
        patterns_images, patterns_images_2, patterns_lpips, patterns_l2 = [], [], [], []
        exists = True
        for legend, base, regex in legend_base_regexs:
            # Measurement counts available under `base` (last path component).
            list_keys = [int_or_float(a.split('/')[-1]) for a in glob.glob(base + '*')]
            print(list_keys)
            if num_measurements not in list_keys:
                exists = False
                break
            pattern = base + str(num_measurements) + regex
            # NOTE(review): the legends above are 'MAP', 'Modified-MAP' and
            # 'Langevin(Ours)', so only 'MAP' matches this list; 'Langevin(Ours)'
            # falls back to the l2 criterion. Confirm whether that is intended.
            if 'glow' in regex and legend in ['MAP', 'Langevin']:
                criterion = ['likelihood', 'mean']
            else:
                criterion = ['l2', 'mean']

            _, best_dir = find_best(pattern, criterion, retrieve_list)
            print(best_dir)
            patterns_images.append(best_dir + '/{0}.png')
            patterns_images_2.append(best_dir + '/images/{:06d}.png')
            patterns_lpips.append(best_dir + '/lpips_scores.pkl')
            patterns_l2.append(best_dir + '/l2_losses.pkl')
        print(patterns_images)
        if not exists:
            continue

        try:
            view(xs_dict, patterns_images, patterns_lpips, patterns_l2,
                 images_nums, hparams)
        except FileNotFoundError:
            # Fall back to the alternative image layout. This needs its own try:
            # an exception raised inside an except clause is not caught by
            # sibling except clauses of the same try statement.
            try:
                view(xs_dict, patterns_images_2, patterns_lpips, patterns_l2,
                     images_nums, hparams)
            except FileNotFoundError:
                pass

        # NOTE(review): `criterion` here is whichever value the last legend entry
        # selected in the loop above.
        save_path = f'./results/celebA_reconstr_{num_measurements}_{criterion[0]}_ncsnv2_orig_map_langevin.pdf'
        utils.save_plot(is_save, save_path)
示例#5
0
def main(hparams):
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)
    xs_dict = model_input(hparams)
    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)
    x_hats_dict = {'dcgan': {}}
    x_batch_dict = {}
    for key, x in xs_dict.iteritems():
        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue
        x_coll = [
            x.reshape(1, hparams.n_input) for _, x in x_batch_dict.iteritems()
        ]
        x_batch = np.concatenate(x_coll)
        A_outer = utils.get_outer_A(hparams)
        # 1bitify
        y_batch_outer = np.sign(np.matmul(x_batch, A_outer))

        x_main_batch = 0.0 * x_batch
        z_opt_batch = np.random.randn(hparams.batch_size, 100)
        for k in range(maxiter):
            x_est_batch = x_main_batch + hparams.outer_learning_rate * (
                np.matmul(
                    (y_batch_outer -
                     np.sign(np.matmul(x_main_batch, A_outer))), A_outer.T))
            estimator = estimators['dcgan']
            x_hat_batch, z_opt_batch = estimator(x_est_batch, z_opt_batch,
                                                 hparams)
            x_main_batch = x_hat_batch

        for i, key in enumerate(x_batch_dict.keys()):
            x = xs_dict[key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]
            x_hats_dict['dcgan'][key] = x_hat
            measurement_losses['dcgan'][key] = utils.get_measurement_loss(
                x_hat, A_outer, y)
            l2_losses['dcgan'][key] = utils.get_l2_loss(x_hat, x)
        print 'Processed upto image {0} / {1}'.format(key + 1, len(xs_dict))
        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter
                                      == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            print '\nProcessed and saved first ', key + 1, 'images\n'

        x_batch_dict = {}

    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print '\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print model_type
            mean_m_loss = np.mean(measurement_losses[model_type].values())
            mean_l2_loss = np.mean(l2_losses[model_type].values())
            print 'mean measurement loss = {0}'.format(mean_m_loss)
            print 'mean l2 loss = {0}'.format(mean_l2_loss)

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processsed
    if len(x_batch_dict) > 0:
        print '\nDid NOT process last {} images because they did not fill up the last batch.'.format(
            len(x_batch_dict))
        print 'Consider rerunning lazily with a smaller batch size.'
def main():
    """Plot CelebA RealNVP reconstructions (annealed MAP vs annealed Langevin) and save PDFs.

    For every measurement count, each method's run directory is chosen by
    ``find_best`` using the ``['l2', 'mean']`` criterion. If any method has no
    sub-directory for that count, the count is skipped.
    """
    hparams = Hparams()
    xs_dict = celebA_input.model_input(hparams)
    start, stop = 0, 5
    images_nums = get_image_nums(start, stop, hparams)
    is_save = True

    # (legend, base directory, run-directory glob appended after the measurement count)
    legend_base_regexs = [
        ('MAP', './estimated/celebA/full-input/circulant/4.0/',
         '/realnvp/annealed_map/*'),
        ('Langevin', './estimated/celebA/full-input/circulant/4.0/',
         '/realnvp/annealed_langevin/*')
    ]
    criterion = ['l2', 'mean']
    retrieve_list = [['l2', 'mean'], ['l2', 'std']]

    for num_measurements in [100, 200, 500, 1000, 2500, 5000, 7500, 10000]:
        patterns_images, patterns_lpips, patterns_l2 = [], [], []
        exists = True
        for _legend, base, regex in legend_base_regexs:
            # Measurement counts available under `base` (last path component).
            list_keys = [int_or_float(a.split('/')[-1]) for a in glob.glob(base + '*')]
            if num_measurements not in list_keys:
                exists = False
                break
            pattern = base + str(num_measurements) + regex
            _, best_dir = find_best(pattern, criterion, retrieve_list)
            patterns_images.append(best_dir + '/{0}.png')
            patterns_lpips.append(best_dir + '/lpips_scores.pkl')
            patterns_l2.append(best_dir + '/l2_losses.pkl')
        print(patterns_images)
        if not exists:
            continue

        view(xs_dict, patterns_images, patterns_lpips, patterns_l2,
             images_nums, hparams)
        save_path = f'./results/celebA_reconstr_{num_measurements}_{criterion[0]}_nvp_orig_map_langevin.pdf'
        utils.save_plot(is_save, save_path)