def main():
    """Build and save one reconstruction image matrix per measurement count.

    For every measurement count, the per-image PNG path templates of the
    three methods (lasso-dct, lasso-wavelet, dcgan) are assembled and handed
    to ``view`` together with the original images; the resulting plot is then
    passed to ``utils.save_plot`` with a measurement-specific PDF path.
    """
    hparams = Hparams()
    xs_dict = celebA_input.model_input(hparams)
    images_nums = get_image_nums(20, 30, hparams)
    should_save = True

    # Everything up to the measurement count is shared by all three methods.
    estimates_root = './estimated/celebA/full-input/gaussian/0.01/'
    # One tail per method; `{0}` is left unformatted here and is part of the
    # template handed to `view`.
    method_tails = (
        '/lasso-dct/0.1/{0}.png',
        '/lasso-wavelet/1e-05/{0}.png',
        '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_10/{0}.png',
    )
    output_template = './results/celebA_reconstr_{}_orig_lasso-dct_lasso-wavelet_dcgan.pdf'

    for num_measurements in [50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]:
        patterns = [
            estimates_root + str(num_measurements) + tail
            for tail in method_tails
        ]
        view(xs_dict, patterns, images_nums, hparams)
        utils.save_plot(should_save, output_template.format(num_measurements))
def main(hparams):
    """Run every configured estimator over the input images, batch by batch.

    Images from ``model_input(hparams)`` are accumulated until a full batch
    of ``hparams.batch_size`` is available. Measurements are then built with
    ``utils.get_measurements`` and each estimator in ``hparams.model_types``
    is run on them. Per-image estimates, measurement losses, l2 losses,
    LPIPS scores and latent codes are stored under the image key and
    periodically handed to ``utils.checkpoint``.

    Args:
        hparams: Hyper-parameter object. This function reads
            ``model_types``, ``noise_std``, ``batch_size``,
            ``num_measurements``, ``not_lazy``, ``n_input``,
            ``image_shape``, ``save_images``, ``checkpoint_iter``,
            ``print_stats`` and ``image_matrix``.

    Note:
        Images left over in a final, partially filled batch are NOT
        processed; a warning is printed instead.
    """
    # set up perceptual loss
    device = 'cuda:0'
    percept = PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses, lpips_scores, z_hats = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}

    A = utils.get_A(hparams)
    # The noise is drawn once here and reused for every batch.
    noise_batch = hparams.noise_std * np.random.standard_t(
        2, size=(hparams.batch_size, hparams.num_measurements))

    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all(os.path.isfile(save_path) for save_path in save_paths.values())
            if is_saved:
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input
        x_batch_list = [x_img.reshape(1, hparams.n_input) for x_img in x_batch_dict.values()]
        x_batch = np.concatenate(x_batch_list)

        # Construct noise and measurements
        y_batch = utils.get_measurements(x_batch, A, noise_batch, hparams)

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            x_hat_batch, z_hat_batch, m_loss_batch = estimator(A, y_batch, hparams)

            for i, batch_key in enumerate(x_batch_dict):
                x_true = xs_dict[batch_key]
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][batch_key] = x_hat

                # Compute and store measurement and l2 loss.
                # The estimator only receives (A, y_batch, hparams), so its
                # outputs are indexed by position in the batch (i), never by
                # the global image key.
                measurement_losses[model_type][batch_key] = m_loss_batch[i]
                l2_losses[model_type][batch_key] = utils.get_l2_loss(x_hat, x_true)
                lpips_scores[model_type][batch_key] = utils.get_lpips_score(
                    percept, x_hat, x_true, hparams.image_shape)
                z_hats[model_type][batch_key] = z_hat_batch[i]

        # `key` is the image that just completed this batch.
        print('Processed upto image {0} / {1}'.format(key+1, len(xs_dict)))

        # Checkpointing
        if (hparams.save_images) and ((key+1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
            # Estimates were handed to the checkpoint; start a fresh buffer.
            x_hats_dict = {model_type: {} for model_type in hparams.model_types}
            print('\nProcessed and saved first ', key+1, 'images\n')

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            measurement_loss_list = list(measurement_losses[model_type].values())
            l2_loss_list = list(l2_losses[model_type].values())
            mean_m_loss = np.mean(measurement_loss_list)
            mean_l2_loss = np.mean(l2_loss_list)
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
def main(hparams):
    """Run the iterative outer loop with the 'dcgan' estimator on each batch.

    Images from ``model_input(hparams)`` are accumulated into batches of
    ``hparams.batch_size``. For each full batch, linear measurements
    ``y = x A`` are formed and ``hparams.max_outer_iter`` outer iterations
    are run. Each iteration takes the step
    ``x_est = x + lr * (y - x A) A^T`` and then passes ``x_est`` with the
    current latent batch to the 'dcgan' estimator, whose output becomes the
    next iterate. Losses and estimates are stored per image key and
    periodically checkpointed.

    Args:
        hparams: Hyper-parameter object. This function reads
            ``image_shape``, ``max_outer_iter``, ``lazy``, ``batch_size``,
            ``outer_learning_rate``, ``save_images``, ``checkpoint_iter``,
            ``print_stats``, ``model_types`` and ``image_matrix``, and
            writes ``n_input``.

    Note:
        Images left over in a final, partially filled batch are NOT
        processed; a warning is printed instead. ``max_outer_iter`` must be
        at least 1, otherwise no estimate is ever produced for a batch.
    """
    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {'dcgan': {}}
    x_batch_dict = {}
    for key, x in xs_dict.items():
        if hparams.lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all(os.path.isfile(save_path) for save_path in save_paths.values())
            if is_saved:
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input
        x_batch_list = [x_img.reshape(1, hparams.n_input) for x_img in x_batch_dict.values()]
        x_batch = np.concatenate(x_batch_list)

        # Construct measurements
        A_outer = utils.get_outer_A(hparams)
        y_batch_outer = np.matmul(x_batch, A_outer)

        # Start the outer iteration from the all-zeros image and a random
        # latent batch with 100 columns.
        x_main_batch = 0.0 * x_batch
        z_opt_batch = np.random.randn(hparams.batch_size, 100)

        estimator = estimators['dcgan']
        for k in range(maxiter):
            # Step on the measurement residual, then let the estimator
            # produce the next iterate.
            residual = y_batch_outer - np.matmul(x_main_batch, A_outer)
            x_est_batch = x_main_batch + hparams.outer_learning_rate * np.matmul(residual, A_outer.T)
            x_hat_batch, z_opt_batch = estimator(x_est_batch, z_opt_batch, hparams)
            x_main_batch = x_hat_batch

        for i, batch_key in enumerate(x_batch_dict):
            x_true = xs_dict[batch_key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]

            # Save the estimate
            x_hats_dict['dcgan'][batch_key] = x_hat

            # Compute and store measurement and l2 loss
            measurement_losses['dcgan'][batch_key] = utils.get_measurement_loss(x_hat, A_outer, y)
            l2_losses['dcgan'][batch_key] = utils.get_l2_loss(x_hat, x_true)

        # `key` is the image that just completed this batch.
        print('Processed upto image {0} / {1}'.format(key+1, len(xs_dict)))

        # Checkpointing
        if (hparams.save_images) and ((key+1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
            # x_hats_dict is deliberately NOT reset here (reset left disabled):
            #x_hats_dict = {'dcgan' : {}}
            print('\nProcessed and saved first ', key+1, 'images\n')

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            # Materialize the dict views: np.mean cannot reduce a dict view.
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
def main():
    """Make and save image matrices for the glow MAP / Langevin results.

    For each measurement count, the best result directory of every entry in
    ``legend_base_regexs`` is located with ``find_best``. When all entries
    have results for that count, the reconstructions are drawn with ``view``
    and the plot is handed to ``utils.save_plot``. Measurement counts for
    which any entry has no results are skipped.

    Two on-disk image layouts are tried in order: ``<dir>/{0}.png`` and then
    ``<dir>/images/{:06d}.png``. If neither is found the view step is
    skipped (best effort) and a notice is printed.
    """
    hparams = Hparams()
    xs_dict = celebA_input.model_input(hparams)
    start, stop = 0, 5
    images_nums = get_image_nums(start, stop, hparams)
    is_save = True

    # Alternative (disabled) configuration for the ffhq-69000 / ncsnv2 runs:
    #legend_base_regexs = [
    #    ('MAP',
    #     f'./estimated/ffhq-69000/full-input/circulant/{hparams.noise_std}/',
    #     '/ncsnv2/map/*'),
    #    ('Deep-Decoder',
    #     f'./estimated/ffhq-69000/full-input/circulant/{hparams.noise_std}/',
    #     '/dd/map/*'),
    #    ('Langevin',
    #     f'./estimated/ffhq-69000/full-input/circulant/{hparams.noise_std}/',
    #     '/ncsnv2/langevin/*')
    #]
    #criterion = ['lpips', 'mean']

    legend_base_regexs = [
        ('MAP',
         f'./estimated*/celebA/full-input/circulant/{hparams.noise_std}/',
         '/glow*map*/*'),
        ('Modified-MAP',
         f'./estimated*/celebA/full-input/circulant/{hparams.noise_std}/',
         '/glow*map*/*'),
        ('Langevin(Ours)',
         f'./estimated*/celebA/full-input/circulant/{hparams.noise_std}/',
         '/glow*langevin/None_None*'),
    ]
    retrieve_list = [['lpips', 'mean'], ['lpips', 'std']]

    for num_measurements in [2500, 5000, 10000, 15000, 20000, 30000, 35000]:
        #for num_measurements in [5000,10000,15000,20000,40000,50000,75000]:
        #for num_measurements in [100,200,500,1000,2500,5000,7500,10000]:
        patterns_images, patterns_images_2, patterns_lpips, patterns_l2 = [], [], [], []
        exists = True
        for legend, base, regex in legend_base_regexs:
            # The last path component of every match under `base` is a
            # measurement count.
            list_keys = [int_or_float(a.split('/')[-1]) for a in glob.glob(base + '*')]
            print(list_keys)
            if num_measurements not in list_keys:
                exists = False
                break
            pattern = base + str(num_measurements) + regex
            # NOTE(review): the legend above is 'Langevin(Ours)', which never
            # matches 'Langevin' here, so that entry is selected by l2.
            # Confirm this is intended.
            if 'glow' in regex and legend in ['MAP', 'Langevin']:
                criterion = ['likelihood', 'mean']
            else:
                criterion = ['l2', 'mean']
            _, best_dir = find_best(pattern, criterion, retrieve_list)
            print(best_dir)
            patterns_images.append(best_dir + '/{0}.png')
            patterns_images_2.append(best_dir + '/images/{:06d}.png')
            patterns_lpips.append(best_dir + '/lpips_scores.pkl')
            patterns_l2.append(best_dir + '/l2_losses.pkl')

        print(patterns_images)
        if not exists:
            continue

        try:
            view(xs_dict, patterns_images, patterns_lpips, patterns_l2, images_nums, hparams)
        except FileNotFoundError:
            # A second `except` clause on the same `try` would not cover an
            # exception raised inside this handler, so the fallback layout
            # needs its own try block.
            try:
                view(xs_dict, patterns_images_2, patterns_lpips, patterns_l2, images_nums, hparams)
            except FileNotFoundError:
                print(f'No images found in either layout for {num_measurements} measurements')

        # patterns = [pattern2, pattern3]
        # view(xs_dict, patterns, images_nums, hparams)
        #save_path = f'./results/ffhq-69000_reconstr_{num_measurements}_{criterion[0]}_ncsnv2_orig_map_langevin.pdf'
        # `criterion` is whatever the last entry of legend_base_regexs selected.
        save_path = f'./results/celebA_reconstr_{num_measurements}_{criterion[0]}_ncsnv2_orig_map_langevin.pdf'
        utils.save_plot(is_save, save_path)
def main(hparams):
    """Run the 1-bit variant of the outer loop with the 'dcgan' estimator.

    Images from ``model_input(hparams)`` are accumulated into batches of
    ``hparams.batch_size``. For each full batch, sign measurements
    ``y = sign(x A)`` are formed and ``hparams.max_outer_iter`` outer
    iterations are run. Each iteration takes the step
    ``x_est = x + lr * (y - sign(x A)) A^T`` and then passes ``x_est`` with
    the current latent batch to the 'dcgan' estimator, whose output becomes
    the next iterate. Losses and estimates are stored per image key and
    periodically checkpointed.

    Args:
        hparams: Hyper-parameter object. This function reads
            ``image_shape``, ``max_outer_iter``, ``batch_size``,
            ``outer_learning_rate``, ``save_images``, ``checkpoint_iter``,
            ``print_stats``, ``model_types`` and ``image_matrix``, and
            writes ``n_input``.

    Note:
        Images left over in a final, partially filled batch are NOT
        processed; a warning is printed instead. ``max_outer_iter`` must be
        at least 1, otherwise no estimate is ever produced for a batch.
    """
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)
    xs_dict = model_input(hparams)
    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {'dcgan': {}}
    x_batch_dict = {}
    for key, x in xs_dict.items():
        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Flatten every image to one row and stack them into the batch.
        x_coll = [
            x_img.reshape(1, hparams.n_input) for x_img in x_batch_dict.values()
        ]
        x_batch = np.concatenate(x_coll)
        A_outer = utils.get_outer_A(hparams)

        # 1bitify
        y_batch_outer = np.sign(np.matmul(x_batch, A_outer))

        # Start the outer iteration from the all-zeros image and a random
        # latent batch with 100 columns.
        x_main_batch = 0.0 * x_batch
        z_opt_batch = np.random.randn(hparams.batch_size, 100)

        estimator = estimators['dcgan']
        for k in range(maxiter):
            # Step on the sign-measurement residual, then let the estimator
            # produce the next iterate.
            residual = y_batch_outer - np.sign(np.matmul(x_main_batch, A_outer))
            x_est_batch = x_main_batch + hparams.outer_learning_rate * np.matmul(residual, A_outer.T)
            x_hat_batch, z_opt_batch = estimator(x_est_batch, z_opt_batch, hparams)
            x_main_batch = x_hat_batch

        for i, batch_key in enumerate(x_batch_dict):
            x_true = xs_dict[batch_key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]
            x_hats_dict['dcgan'][batch_key] = x_hat
            measurement_losses['dcgan'][batch_key] = utils.get_measurement_loss(
                x_hat, A_outer, y)
            l2_losses['dcgan'][batch_key] = utils.get_l2_loss(x_hat, x_true)

        # `key` is the image that just completed this batch.
        print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))

        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            print('\nProcessed and saved first ', key + 1, 'images\n')

        x_batch_dict = {}

    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            # Materialize the dict views: np.mean cannot reduce a dict view.
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(
            len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
def main():
    """Make and save image matrices for the realnvp annealed MAP / Langevin runs.

    For each measurement count, the best result directory of every method
    is located with ``find_best`` using the l2-mean criterion. When all
    methods have results for that count, the reconstructions are drawn with
    ``view`` and the plot is handed to ``utils.save_plot``; otherwise the
    count is skipped.
    """
    hparams = Hparams()
    xs_dict = celebA_input.model_input(hparams)
    images_nums = get_image_nums(0, 5, hparams)
    is_save = True

    def formatted(f):
        # Fixed 4-decimal rendering with trailing zeros (and a dangling
        # dot) stripped. Only referenced by retired, commented-out path
        # construction.
        return format(f, '.4f').rstrip('0').rstrip('.')

    circulant_root = './estimated/celebA/full-input/circulant/4.0/'
    legend_base_regexs = [
        ('MAP', circulant_root, '/realnvp/annealed_map/*'),
        ('Langevin', circulant_root, '/realnvp/annealed_langevin/*'),
    ]
    criterion = ['l2', 'mean']
    retrieve_list = [['l2', 'mean'], ['l2', 'std']]

    # NOTE: result directories used to be spelled out by hand here, e.g.
    #   <root><m>/realnvp/annealed_map/None_200.0_10.0_20.0_4.0_False_sgd_0.001_0.0_2000_1/
    #   <root><m>/realnvp/annealed_langevin/None_None_200.0_10.0_20.0_4.0_False_sgd_0.0005_0.0_2000_1/
    # They are now looked up through find_best.
    for num_measurements in [100, 200, 500, 1000, 2500, 5000, 7500, 10000]:
        patterns_images = []
        patterns_lpips = []
        patterns_l2 = []
        all_present = True

        for _legend, base, regex in legend_base_regexs:
            # The last path component of every match under `base` is a
            # measurement count.
            available = [
                int_or_float(path.split('/')[-1])
                for path in glob.glob(base + '*')
            ]
            if num_measurements not in available:
                all_present = False
                break

            _, best_dir = find_best(
                base + str(num_measurements) + regex, criterion, retrieve_list)
            patterns_images.append(best_dir + '/{0}.png')
            patterns_lpips.append(best_dir + '/lpips_scores.pkl')
            patterns_l2.append(best_dir + '/l2_losses.pkl')

        print(patterns_images)
        if not all_present:
            continue

        view(xs_dict, patterns_images, patterns_lpips, patterns_l2,
             images_nums, hparams)
        save_path = f'./results/celebA_reconstr_{num_measurements}_{criterion[0]}_nvp_orig_map_langevin.pdf'
        utils.save_plot(is_save, save_path)