def vae_estimator(hparams):
    """Build an estimator that fits a VAE latent code to directly observed images.

    The measurement model is the identity: the generator output is compared
    with ``y_batch`` element-wise and only the latent ``z_batch`` is
    optimised. Generator weights are restored from the checkpoint returned by
    ``mnist_model_def.vae_gen``.

    Args:
        hparams: hyper-parameter object. Read here: ``batch_size``,
            ``n_input``, ``mloss1_weight``, ``mloss2_weight``,
            ``zprior_weight`` (plus whatever the ``utils`` helpers read).

    Returns:
        ``estimator(y_batch_val, z_batch_val, hparams)``: runs
        ``hparams.num_random_restarts`` optimisations of
        ``hparams.max_update_iter`` steps, each restart starting from
        ``z_batch_val``, and returns ``utils.BestKeeper.get_best()``.
    """
    # Get a session
    sess = tf.Session()

    # Set up placeholders
    y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.n_input),
                             name='y_batch')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, x_hat_batch, restore_path, restore_dict = mnist_model_def.vae_gen(hparams)

    # Identity measurement of the estimate
    y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')

    # Per-example losses: mean absolute error, mean squared error, latent prior
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # Weighted total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
        + hparams.mloss2_weight * m_loss2_batch \
        + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Batch means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Gradient descent on the latent code only
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss, var_list=var_list, global_step=global_step,
                             name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Build the "reset z to the caller's start point" op once and feed the value
    # through a placeholder. Calling z_batch.assign(<numpy value>) inside the
    # estimator added a new op (with the value embedded as a constant) to the
    # graph on every call, so the graph grew without bound across calls.
    z_init = tf.placeholder(z_batch.dtype.base_dtype, shape=z_batch.get_shape(),
                            name='z_init')
    assign_z_opt_op = z_batch.assign(z_init)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    step_fetches = [update_op, learning_rate, total_loss, m_loss1, m_loss2, zp_loss]
    logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'

    def estimator(y_batch_val, z_batch_val, hparams):
        """Return the best estimate found when starting every restart from ``z_batch_val``."""
        best_keeper = utils.BestKeeper(hparams)
        # Keep the caller's start point: z_batch_val is rebound inside the loop.
        z_start_val = z_batch_val
        feed_dict = {y_batch: y_batch_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            sess.run(assign_z_opt_op, feed_dict={z_init: z_start_val})
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, m_loss1_val, m_loss2_val, zp_loss_val = sess.run(
                    step_fetches, feed_dict=feed_dict)
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val, zp_loss_val))
            x_hat_batch_val, z_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, z_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, z_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
def vae_estimator(hparams):
    """Create a VAE-prior estimator whose graph lives in its own ``tf.Graph``.

    The generator comes from ``construct_gen(hparams, vae_model_def, 'gen')``.
    The estimate is measured either directly (``measurement_type ==
    'project'``) or through the matrix ``A``. Only the latent ``z_batch`` is
    optimised.

    Args:
        hparams: hyper-parameter object (``n_input``, ``num_measurements``,
            ``batch_size``, ``measurement_type`` and the loss weights are
            read here).

    Returns:
        ``estimator(A_val, y_batch_val, hparams)``, which returns
        ``utils.BestKeeper.get_best()`` after all random restarts.
    """
    # Drop whatever is in the global default graph, then build in a fresh one.
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True),
            allow_soft_placement=True)
        sess = tf.Session(config=session_config)

        # Values fed at estimation time.
        A = tf.placeholder(tf.float32, shape=(hparams.n_input, hparams.num_measurements),
                           name='A')
        y_batch = tf.placeholder(tf.float32,
                                 shape=(hparams.batch_size, hparams.num_measurements),
                                 name='y_batch')

        # Generator: latent variable, decoded estimate and restore information.
        z_batch, x_hat_batch, _, restore_path, restore_dict = construct_gen(
            hparams, vae_model_def, 'gen')

        # Forward measurement of the current estimate.
        if hparams.measurement_type != 'project':
            y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')
        else:
            y_hat_batch = tf.identity(x_hat_batch, name='y_hat_batch')

        # Per-example data-fit and prior terms.
        residual = y_batch - y_hat_batch
        m_loss1_batch = tf.reduce_mean(tf.abs(residual), 1)
        m_loss2_batch = tf.reduce_mean(residual**2, 1)
        zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

        total_loss_batch = (hparams.mloss1_weight * m_loss1_batch
                            + hparams.mloss2_weight * m_loss2_batch) \
            + hparams.zprior_weight * zp_loss_batch
        total_loss = tf.reduce_mean(total_loss_batch)

        # Batch averages, fetched alongside every update step.
        m_loss1 = tf.reduce_mean(m_loss1_batch)
        m_loss2 = tf.reduce_mean(m_loss2_batch)
        zp_loss = tf.reduce_mean(zp_loss_batch)

        # Optimiser acting on the latent code only.
        trainable = [z_batch]
        global_step = tf.Variable(0, trainable=False, name='global_step')
        learning_rate = utils.get_learning_rate(global_step, hparams)
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss, var_list=trainable,
                                 global_step=global_step, name='update_op')
        opt_reinit_op = utils.get_opt_reinit_op(opt, trainable, global_step)

        # Initialise everything, then overwrite generator weights from disk.
        sess.run(tf.global_variables_initializer())
        tf.train.Saver(var_list=restore_dict).restore(sess, restore_path)

    step_fetches = [update_op, learning_rate, total_loss, m_loss1, m_loss2, zp_loss]

    def estimator(A_val, y_batch_val, hparams):
        """Optimise z against the given measurements and return the best estimate."""
        keeper = utils.BestKeeper(hparams)
        feed = {y_batch: y_batch_val}
        if hparams.measurement_type != 'project':
            feed[A] = A_val
        for _restart in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for _step in range(hparams.max_update_iter):
                sess.run(step_fetches, feed_dict=feed)
            estimate, per_example_loss = sess.run(
                [x_hat_batch, total_loss_batch], feed_dict=feed)
            keeper.report(estimate, per_example_loss)
        return keeper.get_best()

    return estimator
def dcgan_l1_estimator(hparams, model_type):
    """Build a DCGAN + sparse-deviation estimator in a private TensorFlow graph.

    The estimate is ``x_estim = x_hat_batch + nu_estim``: the generator output
    plus a deviation ``nu_estim``. The deviation is penalised in the pixel
    domain (``'dcgan_l1'``), through the inverse of ``wavelet_basis()``
    (``'dcgan_l1_wavelet'``), or through a 64x64 ``dct2`` basis applied per
    colour channel (``'dcgan_l1_dct'``).

    Args:
        hparams: hyper-parameter object (shapes, loss weights including
            ``sparse_gen_weight``, ``dloss1_weight``/``dloss2_weight``,
            optimiser settings).
        model_type: one of ``'dcgan_l1'``, ``'dcgan_l1_wavelet'``,
            ``'dcgan_l1_dct'``.

    Returns:
        ``estimator(A_val, y_batch_val, hparams)`` returning
        ``utils.BestKeeper.get_best()`` over the random restarts.

    Raises:
        ValueError: if ``model_type`` is not one of the supported values
            (previously this surfaced as a NameError on ``l1_loss``).
    """
    # pylint: disable = C0326
    valid_types = ('dcgan_l1', 'dcgan_l1_wavelet', 'dcgan_l1_dct')
    if model_type not in valid_types:
        raise ValueError('Unknown model_type {!r}; expected one of {}'.format(
            model_type, valid_types))

    tf.reset_default_graph()
    g1 = tf.Graph()
    with g1.as_default():
        # Set up placeholders
        A = tf.placeholder(tf.float32, shape=(hparams.n_input, hparams.num_measurements),
                           name='A')
        y_batch = tf.placeholder(tf.float32,
                                 shape=(hparams.batch_size, hparams.num_measurements),
                                 name='y_batch')

        # Create the generator
        z_batch = tf.Variable(tf.random_normal([hparams.batch_size, 100]), name='z_batch')
        x_hat_batch, restore_dict_gen, restore_path_gen = dcgan_gen(z_batch, hparams)

        # Create the discriminator
        prob, restore_dict_discrim, restore_path_discrim = dcgan_discrim(x_hat_batch, hparams)

        # Deviation from the generator range, initialised to zero.
        nu_estim = tf.get_variable("x_estim", dtype=tf.float32,
                                   shape=x_hat_batch.get_shape(),
                                   initializer=tf.constant_initializer(0))
        x_estim = nu_estim + x_hat_batch

        # Measure the estimate
        if hparams.measurement_type == 'project':
            y_hat_batch = tf.identity(x_estim, name='y2_batch')
        else:
            y_hat_batch = tf.matmul(x_estim, A, name='y2_batch')

        # Define all losses
        m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
        m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
        zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
        d_loss1_batch = -tf.log(prob)
        d_loss2_batch = tf.log(1 - prob)

        # Sparsity penalty on nu in the basis selected by model_type.
        if model_type == 'dcgan_l1':
            l1_loss = tf.reduce_sum(tf.abs(nu_estim), 1)
        elif model_type == 'dcgan_l1_wavelet':
            W = wavelet_basis()
            Winv = np.linalg.inv(W)
            l1_loss = tf.reduce_sum(
                tf.abs(tf.matmul(nu_estim, tf.constant(Winv, dtype=tf.float32))), 1)
        else:  # 'dcgan_l1_dct'
            # The same 64x64 DCT matrix for every (example, channel) pair;
            # computed once and tiled rather than recomputed 3*batch_size times.
            dct_basis = dct2(np.eye(64))
            dct_proj = np.tile(dct_basis, (hparams.batch_size, 3, 1, 1))
            # NOTE(review): assumes nu_estim reshapes to (-1, 64, 64, 3), i.e.
            # 64x64 RGB images flattened row-major -- confirm against dcgan_gen.
            nu_re = tf.transpose(tf.reshape(nu_estim, (-1, 64, 64, 3)), [0, 3, 1, 2])
            l1_loss = tf.reduce_sum(
                tf.abs(tf.matmul(nu_re, tf.constant(dct_proj, dtype=tf.float32))),
                [1, 2, 3])

        # Define total loss
        total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
            + hparams.mloss2_weight * m_loss2_batch \
            + hparams.zprior_weight * zp_loss_batch \
            + hparams.dloss1_weight * d_loss1_batch \
            + hparams.dloss2_weight * d_loss2_batch \
            + hparams.sparse_gen_weight * l1_loss
        total_loss = tf.reduce_mean(total_loss_batch)

        # Set up gradient descent.
        # Bug fix: the original `z_batch, var_list = [nu_estim, z_batch]`
        # unpacked the list, rebinding z_batch to nu_estim and var_list to the
        # latent Variable. The warm-up op therefore optimised nu instead of z,
        # and the main op optimised z alone.
        var_list = [nu_estim, z_batch]
        global_step = tf.Variable(0, trainable=False, name='global_step')
        learning_rate = utils.get_learning_rate(global_step, hparams)
        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            opt = utils.get_optimizer(learning_rate, hparams)
            update_op = opt.minimize(total_loss, var_list=var_list,
                                     global_step=global_step, name='update_op')
            # Warm-up phase: move the latent code only.
            update_init_op = opt.minimize(total_loss, var_list=[z_batch],
                                          name='update_init_op')
            # Clip nu so that x_hat_batch + nu_estim stays within [-1, 1].
            nu_estim_clip = nu_estim.assign(
                tf.maximum(tf.minimum(1.0 - x_hat_batch, nu_estim), -1.0 - x_hat_batch))
            opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

        # Initialize and restore model parameters
        init_op = tf.global_variables_initializer()

        # Get a session bound to this graph
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(graph=g1, config=tf.ConfigProto(gpu_options=gpu_options,
                                                          allow_soft_placement=True))
        sess.run(init_op)
        restorer_gen = tf.train.Saver(var_list=restore_dict_gen)
        restorer_discrim = tf.train.Saver(var_list=restore_dict_discrim)
        restorer_gen.restore(sess, restore_path_gen)
        restorer_discrim.restore(sess, restore_path_discrim)

    def estimator(A_val, y_batch_val, hparams):
        """Optimise (z, nu) for the given measurements and return the best estimate.

        When ``max_update_iter > 250``, each restart first runs 250 warm-up
        steps on z alone. The remaining steps update z and nu jointly,
        clipping nu after every step.
        """
        best_keeper = utils.BestKeeper(hparams)
        if hparams.measurement_type == 'project':
            feed_dict = {y_batch: y_batch_val}
        else:
            feed_dict = {A: A_val, y_batch: y_batch_val}
        init_itr_no = 250 if hparams.max_update_iter > 250 else 0
        for _ in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for _ in range(init_itr_no):
                sess.run([update_init_op], feed_dict=feed_dict)
            x_estim_val, total_loss_batch_val = sess.run(
                [x_estim, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_estim_val, total_loss_batch_val)
            for _ in range(int(hparams.max_update_iter - init_itr_no)):
                sess.run(update_op, feed_dict=feed_dict)
                sess.run(nu_estim_clip)
            x_estim_val, total_loss_batch_val = sess.run(
                [x_estim, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_estim_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
def dcgan_estimator(hparams):
    """Build a DCGAN-prior estimator for directly observed (identity-measured) images.

    The generator output is compared element-wise with ``y_batch``. The loss
    combines data fit, a latent prior and two discriminator terms, and only
    the latent ``z_batch`` is optimised.

    Fixes relative to the previous version:
      * ``tf.zeros(x_hat_batch)`` (which expects a shape, not a tensor) is
        replaced with ``tf.identity``, matching the other identity-measurement
        estimators in this file;
      * the misplaced parenthesis in ``m_loss1_batch`` is corrected;
      * ``total_loss`` is defined before it is minimised;
      * the z-reset assign op is built once instead of once per call.

    Args:
        hparams: hyper-parameter object (``batch_size``, ``n_input``, loss
            weights and optimiser settings).

    Returns:
        ``estimator(y_batch_val, z_batch_val, hparams)`` returning
        ``utils.BestKeeper.get_best()``.
    """
    sess = tf.Session()
    y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.n_input),
                             name='y_batch')
    z_batch = tf.Variable(tf.random_normal([hparams.batch_size, 100]), name='z_batch')
    x_hat_batch, restore_dict_gen, restore_path_gen = celebA_model_def.dcgan_gen(
        z_batch, hparams)
    prob, restore_dict_discrim, restore_path_discrim = celebA_model_def.dcgan_discrim(
        x_hat_batch, hparams)

    # Identity measurement of the estimate
    y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')

    # Per-example losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
    d_loss1_batch = -tf.log(prob)
    d_loss2_batch = tf.log(1 - prob)

    # Batch means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)
    d_loss1 = tf.reduce_mean(d_loss1_batch)
    d_loss2 = tf.reduce_mean(d_loss2_batch)

    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
        + hparams.mloss2_weight * m_loss2_batch \
        + hparams.zprior_weight * zp_loss_batch \
        + hparams.dloss1_weight * d_loss1_batch \
        + hparams.dloss2_weight * d_loss2_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Gradient descent on the latent code only
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    # reuse=False so the optimiser can create its own variables even if the
    # discriminator left the current scope in reuse mode.
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss, var_list=var_list,
                                 global_step=global_step, name='update_op')
        opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Reset-z op built once; the start value is fed at run time.
    z_init = tf.placeholder(z_batch.dtype.base_dtype, shape=z_batch.get_shape(),
                            name='z_init')
    assign_z_opt_op = z_batch.assign(z_init)

    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer_gen = tf.train.Saver(var_list=restore_dict_gen)
    restorer_discrim = tf.train.Saver(var_list=restore_dict_discrim)
    restorer_gen.restore(sess, restore_path_gen)
    restorer_discrim.restore(sess, restore_path_discrim)

    step_fetches = [update_op, learning_rate, total_loss, m_loss1, m_loss2, zp_loss,
                    d_loss1, d_loss2]
    logging_format = ('rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} '
                      'zp_loss {} d_loss1 {} d_loss2 {}')

    def estimator(y_batch_val, z_batch_val, hparams):
        """Return the best estimate found when starting every restart from ``z_batch_val``."""
        best_keeper = utils.BestKeeper(hparams)
        # Keep the caller's start point: z_batch_val is rebound inside the loop.
        z_start_val = z_batch_val
        feed_dict = {y_batch: y_batch_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            sess.run(assign_z_opt_op, feed_dict={z_init: z_start_val})
            for j in range(hparams.max_update_iter):
                (_, lr_val, total_loss_val, m_loss1_val, m_loss2_val, zp_loss_val,
                 d_loss1_val, d_loss2_val) = sess.run(step_fetches, feed_dict=feed_dict)
                print(logging_format.format(i, j, lr_val, total_loss_val, m_loss1_val,
                                            m_loss2_val, zp_loss_val, d_loss1_val,
                                            d_loss2_val))
            x_hat_batch_val, z_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, z_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, z_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
def vae_estimator(hparams):
    """Build a VAE-prior estimator with a mean/stdv-normalised latent prior.

    The prior term is ``sum((z - hparams.mean)**2 / hparams.stdv**2)`` per
    example. When ``hparams.stdv <= 0`` the scale is ``1e+20``, which
    effectively pins z to ``hparams.mean``. The estimate is measured directly
    (``measurement_type == 'project'``) or through the matrix ``A``.

    Args:
        hparams: hyper-parameter object (shapes, ``measurement_type``,
            ``mean``, ``stdv`` and the loss weights are read here).

    Returns:
        ``estimator(A_val, y_batch_val, hparams)`` returning
        ``utils.BestKeeper.get_best()``.
    """
    # Get a session (on the current default graph)
    sess = tf.Session()
    # Set up placeholders
    A = tf.placeholder(tf.float32, shape=(hparams.n_input, hparams.num_measurements), name='A')
    y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.num_measurements), name='y_batch')
    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, x_hat_batch, restore_path, restore_dict, _ = mnist_model_def.vae_gen(
        hparams)
    # Measure the estimate: identity for 'project', otherwise x_hat @ A
    if hparams.measurement_type == 'project':
        y_hat_batch = tf.identity(x_hat_batch, name='y_hat_batch')
    else:
        y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')
    # Per-example data-fit losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    # Prior term normalised by the prior variance (unnormalised form was:
    # tf.reduce_sum(z_batch**2, 1)).
    # NOTE(review): under Python 2, `1 / (hparams.stdv**2)` is integer division
    # when stdv is an int (e.g. stdv=2 gives 0) -- confirm stdv is a float.
    if hparams.stdv > 0:
        norm_val = 1 / (hparams.stdv**2)
    else:
        norm_val = 1e+20
    zp_loss_batch = tf.reduce_sum(
        (z_batch - tf.ones(tf.shape(z_batch)) * hparams.mean)**2 * norm_val,
        1)  # added normalization
    # Define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
        + hparams.mloss2_weight * m_loss2_batch \
        + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)
    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)
    # Set up gradient descent on the latent code only
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss, var_list=var_list, global_step=global_step, name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)
    # Initialize all variables, then restore the generator weights
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(A_val, y_batch_val, hparams):
        """Optimise z for the given measurements and return the best estimate.

        Every step is logged. When ``hparams.gif`` is set, every
        ``hparams.gif_iter`` steps each image of the batch is saved as
        ``<gif_dir>/<image index>/<step>.png``.

        Raises:
            ValueError: debug sanity check -- the total loss equals the MSE
                term even though the prior term and its weight are positive.
        """
        best_keeper = utils.BestKeeper(hparams)
        if hparams.measurement_type == 'project':
            # NOTE(review): y_batch is a placeholder of fixed shape
            # (hparams.batch_size, num_measurements); padding of smaller inputs
            # was sketched here but is disabled, so callers must pass full batches.
            feed_dict = {y_batch: y_batch_val}
        else:
            feed_dict = {A: A_val, y_batch: y_batch_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, \
                    m_loss1_val, \
                    m_loss2_val, \
                    zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                            m_loss1,
                                            m_loss2,
                                            zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(
                    logging_format.format(i, j, lr_val, total_loss_val,
                                          m_loss1_val, m_loss2_val, zp_loss_val))
                # Debug check: the prior term appears to have been dropped from
                # the total loss.
                # NOTE(review): exact float equality; could also fire if the
                # weighted prior term is below float resolution of total_loss.
                if total_loss_val == m_loss2_val and zp_loss_val > 0 and hparams.zprior_weight > 0:
                    raise ValueError('NONONO')
                # Optional frame dump for building GIFs of the optimisation.
                if hparams.gif and ((j % hparams.gif_iter) == 0):
                    images = sess.run(x_hat_batch, feed_dict=feed_dict)
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)
            # Report this restart's final estimate and per-example losses.
            x_hat_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
def stage_i(A_val, y_batch_val, hparams, hid_i, init_obj, early_stop, bs, optim, recovered=False):
    """Optimise the latent code of generator stage ``hid_i`` for one measurement batch.

    Builds the graph in the default graph and restores the generator via
    ``ModelSelect.restore``. It then runs ``init_obj.num_random_restarts``
    restarts of up to ``hparams.max_update_iter`` steps each. Every 100 steps
    it reports to a ``utils.BestKeeper`` and pickles the optimisation state
    (``optim``) to ``<checkpoint_dir>/tmp/optim.pkl`` so that an interrupted
    run can be resumed with ``recovered=True``.

    Args:
        A_val: measurement matrix fed to ``A`` (not fed when
            ``hparams.measurement_type == 'project'``).
        y_batch_val: observed measurements fed to ``y_batch``.
        hparams: global hyper-parameters.
        hid_i: hidden-stage identifier (coerced to ``int``).
        init_obj: initialisation descriptor; ``mode`` selects the start point
            (``'random'``, ``'previous-and-random'``, ``'only-previous'`` or
            anything else) and ``num_random_restarts`` the number of restarts.
        early_stop: if true, ``check_tolerance`` is consulted every 100 steps
            and all optimisation stops once it reports success.
        bs: stored in the checkpoint as ``optim.bs``.
        optim: checkpoint record; mutated and pickled every 100 steps.
        recovered: resume from ``hparams.optim`` (restart ``i``, step ``j``,
            ``z_batch``, ``best_keeper``, ``global_step``).

    Returns:
        ``best_keeper.get_best()``.

    Fixes relative to the previous version:
      * the session is closed (also on error) before ``tf.reset_default_graph()``.
        Resetting the graph while a session on it is active is undefined
        behaviour in TF1, and the unclosed sessions leaked device memory;
      * ``1.0 / stdv**2`` avoids Python 2 integer division for integer ``stdv``;
      * unused per-stage dictionaries were removed.
    """
    model_def = globals()['model_def']
    model_selection = ModelSelect(hparams)
    hid_i = int(hid_i)
    uses_previous_init = init_obj.mode in ['previous-and-random', 'only-previous']
    logging_format = 'hid {} rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'

    # Get a session (on the current default graph)
    sess = tf.Session()
    try:
        # Set up placeholders
        A = tf.placeholder(tf.float32, shape=(hparams.n_input, hparams.num_measurements),
                           name='A')
        y_batch = tf.placeholder(tf.float32,
                                 shape=(hparams.batch_size, hparams.num_measurements),
                                 name='y_batch')

        # Generator hyper-parameters for this stage
        model_hparams = model_def.Hparams()
        model_hparams.n_z = hparams.n_z
        model_hparams.stdv = hparams.stdv
        model_hparams.mean = hparams.mean
        model_hparams.grid = copy.deepcopy(hparams.grid)
        model_selection.setup_dim(hid_i, model_hparams)
        model_type = hparams.model_types[0]
        if model_type != 'vae-flex-alt' and 'alt' in model_type:
            model_def.ignore_grid = next(
                (j for j in model_selection.dim_list if j >= hid_i), None)

        # Set up the initialization
        print('The initialization is: {}'.format(init_obj.mode))
        if init_obj.mode == 'random':
            z_batch = model_def.get_z_var(model_hparams, hparams.batch_size, hid_i)
        elif uses_previous_init:
            z_batch = model_def.get_z_var(model_hparams, hparams.batch_size, hid_i)
            init_op_par = tf.assign(
                z_batch, truncate_val(model_hparams, hparams, hid_i, init_obj, stdv=0))
        else:
            z_batch = truncate_val(model_hparams, hparams, hid_i, init_obj, stdv=0.1)
        _, x_hat_batch, _ = model_def.generator_i(
            model_hparams, z_batch, 'gen', hparams.bol, hid_i, relative=False)

        # Measure the estimate
        if hparams.measurement_type == 'project':
            y_hat_batch = tf.identity(x_hat_batch, name='y_hat_batch')
        else:
            y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')

        # Per-example losses
        m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
        m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
        # Prior normalised by the prior variance; stdv <= 0 pins z to the mean.
        if hparams.stdv > 0:
            norm_val = 1.0 / (hparams.stdv**2)
        else:
            norm_val = 1e+20
        zp_loss_batch = tf.reduce_sum((z_batch - hparams.mean)**2 * norm_val, 1)

        # Define total loss
        total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
            + hparams.mloss2_weight * m_loss2_batch \
            + hparams.zprior_weight * zp_loss_batch
        total_loss = tf.reduce_mean(total_loss_batch)

        # Batch means for logging
        m_loss1 = tf.reduce_mean(m_loss1_batch)
        m_loss2 = tf.reduce_mean(m_loss2_batch)
        zp_loss = tf.reduce_mean(zp_loss_batch)

        # Gradient descent on the latent code only
        var_list = [z_batch]
        if recovered:
            global_step = tf.Variable(hparams.optim.global_step, trainable=False,
                                      name='global_step')
        else:
            global_step = tf.Variable(0, trainable=False, name='global_step')
        learning_rate = utils.get_learning_rate(global_step, hparams)
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss, var_list=var_list,
                                 global_step=global_step, name='update_op')
        opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

        # Initialize and restore model parameters
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        factor = 1 if 'alt' in model_type else len(hparams.grid)
        # Set on the shared model_def object because the call to generator_i
        # may have changed its batch size.
        model_def.batch_size = hparams.batch_size * factor
        model_selection.restore(sess, hid_i)

        if recovered:
            best_keeper = hparams.optim.best_keeper
        else:
            best_keeper = utils.BestKeeper(hparams, logg_z=True)
        if hparams.measurement_type == 'project':
            feed_dict = {y_batch: y_batch_val}
        else:
            feed_dict = {A: A_val, y_batch: y_batch_val}

        step_fetches = [update_op, learning_rate, total_loss, m_loss1, m_loss2, zp_loss]
        checkpoint_path = utils.get_checkpoint_dir(hparams, model_type) + 'tmp/optim.pkl'
        stop_all = False
        for i in range(init_obj.num_random_restarts):
            if recovered and i <= hparams.optim.i:
                if i < hparams.optim.i:
                    # Restart already finished before the interruption.
                    continue
                # Resume the interrupted restart. The optimizer's slot state
                # was not checkpointed and is lost here.
                sess.run(utils.get_opt_reinit_op(opt, [], global_step))
                sess.run(tf.assign(z_batch, hparams.optim.z_batch))
            else:
                sess.run(opt_reinit_op)
                # NOTE(review): nesting reconstructed from collapsed source --
                # the previous-outcome start is applied only on a fresh
                # (non-resumed) first restart so it cannot overwrite a restored z.
                if i < 1 and uses_previous_init:
                    print('Using previous outcome as starting point')
                    sess.run(init_op_par)

            for j in range(hparams.max_update_iter):
                if recovered and j < hparams.optim.j:
                    continue
                _, lr_val, total_loss_val, m_loss1_val, m_loss2_val, zp_loss_val = sess.run(
                    step_fetches, feed_dict=feed_dict)

                if hparams.gif and (j % hparams.gif_iter) == 0:
                    images = sess.run(x_hat_batch, feed_dict=feed_dict)
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/{2}/'.format(hparams.gif_dir, hid_i, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)

                if j % 100 == 0 and early_stop:
                    x_hat_batch_val = sess.run(x_hat_batch, feed_dict=feed_dict)
                    if check_tolerance(hparams, A_val, x_hat_batch_val, y_batch_val)[1]:
                        stop_all = True
                        print('Early stopping')
                        break

                if j % 25 == 0:  # log only every 25th step
                    print(logging_format.format(hid_i, i, j, lr_val, total_loss_val,
                                                m_loss1_val, m_loss2_val, zp_loss_val))

                if j % 100 == 0:
                    x_hat_batch_val, total_loss_batch_val, z_batch_val = sess.run(
                        [x_hat_batch, total_loss_batch, z_batch], feed_dict=feed_dict)
                    best_keeper.report(x_hat_batch_val, total_loss_batch_val,
                                       z_val=z_batch_val)
                    # Checkpoint enough state to resume this run later.
                    optim.global_step = sess.run(global_step)
                    optim.A = A_val
                    optim.y_batch = y_batch_val
                    optim.i = i
                    optim.j = j
                    optim.z_batch = z_batch_val
                    optim.best_keeper = best_keeper
                    optim.bs = bs
                    optim.init_obj = init_obj
                    utils.save_to_pickle(optim, checkpoint_path)
                    print('Checkpoint of optimization created')

            # Later restarts must not skip steps.
            # NOTE(review): requires hparams.optim to exist even when not recovering.
            hparams.optim.j = 0
            x_hat_batch_val, total_loss_batch_val, z_batch_val = sess.run(
                [x_hat_batch, total_loss_batch, z_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, total_loss_batch_val, z_val=z_batch_val)
            if stop_all:
                break
    finally:
        # Close the session before discarding the graph it runs on.
        sess.close()
        tf.reset_default_graph()
    return best_keeper.get_best()
def dcgan_estimator(hparams):
    """Create a DCGAN-prior estimator with data-fit, latent-prior and discriminator losses.

    The estimate ``x_hat_batch = G(z_batch)`` is measured either directly
    (``measurement_type == 'project'``) or through ``A``. The matmul is
    flagged ``b_is_sparse`` for the ``'inpaint'`` and ``'superres'``
    measurement types.

    Args:
        hparams: hyper-parameter object (shapes, ``measurement_type``, loss
            weights, optimiser settings).

    Returns:
        ``estimator(A_val, y_batch_val, hparams)`` returning
        ``utils.BestKeeper.get_best()`` after all random restarts.
    """
    # pylint: disable = C0326
    sess = tf.Session()

    # Values fed at estimation time.
    A = tf.placeholder(tf.float32, shape=(hparams.n_input, hparams.num_measurements), name='A')
    y_batch = tf.placeholder(tf.float32,
                             shape=(hparams.batch_size, hparams.num_measurements),
                             name='y_batch')

    # Generator and discriminator built on the trainable latent code.
    z_batch = tf.Variable(tf.random_normal([hparams.batch_size, 100]), name='z_batch')
    x_hat_batch, restore_dict_gen, restore_path_gen = celebA_model_def.dcgan_gen(
        z_batch, hparams)
    prob, restore_dict_discrim, restore_path_discrim = celebA_model_def.dcgan_discrim(
        x_hat_batch, hparams)

    # Forward measurement of the estimate.
    if hparams.measurement_type == 'project':
        y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')
    else:
        sparse_operator = hparams.measurement_type in ['inpaint', 'superres']
        y_hat_batch = tf.matmul(x_hat_batch, A, b_is_sparse=sparse_operator, name='y2_batch')

    # Per-example loss terms.
    residual = y_batch - y_hat_batch
    m_loss1_batch = tf.reduce_mean(tf.abs(residual), 1)
    m_loss2_batch = tf.reduce_mean(residual**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
    d_loss1_batch = -tf.log(prob)
    d_loss2_batch = tf.log(1 - prob)

    total_loss_batch = ((((hparams.mloss1_weight * m_loss1_batch
                           + hparams.mloss2_weight * m_loss2_batch)
                          + hparams.zprior_weight * zp_loss_batch)
                         + hparams.dloss1_weight * d_loss1_batch)
                        + hparams.dloss2_weight * d_loss2_batch)
    total_loss = tf.reduce_mean(total_loss_batch)

    # Batch averages fetched with every step for logging.
    m_loss1, m_loss2, zp_loss, d_loss1, d_loss2 = [
        tf.reduce_mean(term) for term in
        (m_loss1_batch, m_loss2_batch, zp_loss_batch, d_loss1_batch, d_loss2_batch)]

    # Optimiser acting on z only; reuse=False lets it create its own variables.
    trainable = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss, var_list=trainable,
                                 global_step=global_step, name='update_op')
        opt_reinit_op = utils.get_opt_reinit_op(opt, trainable, global_step)

    # Initialise everything, then overwrite network weights from checkpoints.
    sess.run(tf.global_variables_initializer())
    tf.train.Saver(var_list=restore_dict_gen).restore(sess, restore_path_gen)
    tf.train.Saver(var_list=restore_dict_discrim).restore(sess, restore_path_discrim)

    step_fetches = [update_op, learning_rate, total_loss, m_loss1, m_loss2, zp_loss,
                    d_loss1, d_loss2]
    log_template = ('rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} '
                    'zp_loss {} d_loss1 {} d_loss2 {}')

    def dump_gif_frames(step, feed, hp):
        """Save each image of the current estimate as frame ``step`` of its GIF."""
        frames = sess.run(x_hat_batch, feed_dict=feed)
        for idx, frame in enumerate(frames):
            frame_dir = '{0}/{1}/'.format(hp.gif_dir, idx)
            utils.set_up_dir(frame_dir)
            save_image(frame.reshape(hp.image_shape), frame_dir + '{0}.png'.format(step))

    def estimator(A_val, y_batch_val, hparams):
        """Optimise z against the measurements and return the best estimate."""
        keeper = utils.BestKeeper(hparams)
        feed = {y_batch: y_batch_val}
        if hparams.measurement_type != 'project':
            feed[A] = A_val
        for restart in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for step in range(hparams.max_update_iter):
                # Frames are captured before the update of this step.
                if hparams.gif and step % hparams.gif_iter == 0:
                    dump_gif_frames(step, feed, hparams)
                values = sess.run(step_fetches, feed_dict=feed)
                print(log_template.format(restart, step, *values[1:]))
            estimate, per_example_loss = sess.run([x_hat_batch, total_loss_batch],
                                                  feed_dict=feed)
            keeper.report(estimate, per_example_loss)
        return keeper.get_best()

    return estimator
def pggan_estimator(hparams):
    """Build a PGGAN-prior channel estimator.

    The generator maps a latent ``z_batch`` and the fed ``Pilot`` to a channel
    estimate ``H_hat``. The predicted received signal is
    ``utils.calRx(H_hat, Tx, hparams)``. With ``measurement_type == 'pilot'``
    only the pilot entries (``utils.get_tf_pilot``) enter the data-fit loss;
    otherwise the full received signal is used.

    Fixes relative to the previous version:
      * the graph always depends on ``Tx``, ``Rx`` and ``Pilot``, so all three
        are always fed. The old ``'project'`` branch referenced undefined
        ``y_batch``/``y_batch_val``;
      * GIF frames are taken from ``H_hat``. The old code referenced an
        undefined ``x_hat_batch``;
      * dead code held in a string literal was removed, and ``print`` is
        Python 2/3 compatible.

    Args:
        hparams: hyper-parameter object (``modSignal_shape``, ``batch_size``,
            ``pilot_dim``, ``z_dim``, ``measurement_type``, loss weights, ...).

    Returns:
        ``estimator(Tx_val, Rx_val, Pilot_val, hparams)`` returning
        ``utils.BestKeeper.get_best()``.
    """
    # pylint: disable = C0326
    # Get a session
    sess = tf.Session()

    # Set up placeholders
    Tx = tf.placeholder(tf.float32, shape=hparams.modSignal_shape, name='Tx')
    Rx = tf.placeholder(tf.float32, shape=hparams.modSignal_shape, name='Rx')
    Pilot = tf.placeholder(tf.float32, shape=[hparams.batch_size, hparams.pilot_dim],
                           name='Pilot')

    # Create the generator
    z_batch = tf.Variable(tf.random.normal([hparams.batch_size, hparams.z_dim]),
                          name='z_batch')
    H_hat, restore_dict_gen, restore_path_gen = channel_model_def.pggan_gen(
        z_batch, Pilot, hparams)

    # Measure the estimate
    print('H_hat:', H_hat.shape)
    print('Tx:', Tx.shape)
    Rx_hat = utils.calRx(H_hat, Tx, hparams)

    # Define all losses
    if hparams.measurement_type == 'pilot':
        # Only the pilot positions contribute to the data-fit terms.
        # NOTE(review): these terms are not reduced per example, unlike the
        # other branch -- confirm the intended shape of total_loss_batch.
        pilot_rx = utils.get_tf_pilot(Rx)
        pilot_rx_hat = utils.get_tf_pilot(Rx_hat)
        m_loss1_batch = tf.abs(pilot_rx - pilot_rx_hat)
        m_loss2_batch = (pilot_rx - pilot_rx_hat)**2
    else:
        m_loss1_batch = tf.reduce_mean(tf.abs(Rx - Rx_hat), 1)
        m_loss2_batch = tf.reduce_mean((Rx - Rx_hat)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # Define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
        + hparams.mloss2_weight * m_loss2_batch \
        + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Batch means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Gradient descent on the latent code only
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss, var_list=var_list,
                                 global_step=global_step, name='update_op')
        opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer_gen = tf.train.Saver(var_list=restore_dict_gen)
    restorer_gen.restore(sess, restore_path_gen)

    step_fetches = [update_op, learning_rate, total_loss, m_loss1, m_loss2, zp_loss]
    logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'

    def estimator(Tx_val, Rx_val, Pilot_val, hparams):
        """Optimise z for the given signals and return the best channel estimate."""
        best_keeper = utils.BestKeeper(hparams)
        # Every loss and H_hat depend on all three placeholders.
        feed_dict = {Tx: Tx_val, Rx: Rx_val, Pilot: Pilot_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                if hparams.gif and ((j % hparams.gif_iter) == 0):
                    # NOTE(review): assumes each H_hat entry reshapes to
                    # hparams.image_shape -- confirm.
                    images = sess.run(H_hat, feed_dict=feed_dict)
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)
                _, lr_val, total_loss_val, m_loss1_val, m_loss2_val, zp_loss_val = sess.run(
                    step_fetches, feed_dict=feed_dict)
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val, zp_loss_val))
            # NOTE(review): reports the scalar mean loss, whereas the other
            # estimators report per-example losses -- confirm BestKeeper accepts this.
            H_hat_val, total_loss_val = sess.run([H_hat, total_loss], feed_dict=feed_dict)
            best_keeper.report(H_hat_val, total_loss_val)
        return best_keeper.get_best()

    return estimator
def vae_estimator(hparams):
    """Build a VAE-prior channel estimator.

    ``channel_model_def.vae_gen`` supplies the latent ``z_batch`` and the
    channel estimate ``H_hat``. The predicted received signal ``Rx_hat`` is
    chosen by measurement type:
      * ``'project'``: ``H_hat`` itself;
      * ``'pilot'``: ``utils.multiComplex(H_hat, Tx)``;
      * otherwise: the element-wise product ``H_hat * Tx``.

    Fixes relative to the previous version:
      * the ``'project'`` branch used an undefined ``x_hat_batch``; it now
        uses ``H_hat``;
      * ``print`` is Python 2/3 compatible.

    Args:
        hparams: hyper-parameter object (``image_shape``,
            ``measurement_type``, loss weights, optimiser settings).

    Returns:
        ``estimator(Tx_val, Rx_val, hparams)`` returning
        ``utils.BestKeeper.get_best()``.
    """
    # Get a session
    sess = tf.Session()

    # Set up placeholders
    Tx = tf.placeholder(tf.float32, shape=hparams.image_shape, name='Tx')
    Rx = tf.placeholder(tf.float32, shape=hparams.image_shape, name='Rx')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, H_hat, restore_path, restore_dict = channel_model_def.vae_gen(hparams)

    # Measure the estimate
    if hparams.measurement_type == 'project':
        Rx_hat = tf.identity(H_hat, name='y_hat_batch')
    elif hparams.measurement_type == 'pilot':
        Rx_hat = utils.multiComplex(H_hat, Tx)
    else:
        Rx_hat = tf.multiply(H_hat, Tx, name='y_hat')

    # Define all losses
    # NOTE(review): the data terms are averaged over axes 1 and then 0, which
    # may include the batch axis, depending on image_shape -- confirm intent.
    m_loss1_batch = tf.reduce_mean(tf.reduce_mean(tf.abs(Rx - Rx_hat), 1), 0)
    m_loss2_batch = tf.reduce_mean(tf.reduce_mean((Rx - Rx_hat)**2, 1), 0)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # Define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
        + hparams.mloss2_weight * m_loss2_batch \
        + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Batch means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Gradient descent on the latent code only
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss, var_list=var_list,
                             global_step=global_step, name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    step_fetches = [update_op, learning_rate, total_loss, m_loss1, m_loss2, zp_loss]
    logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'

    def estimator(Tx_val, Rx_val, hparams):
        """Optimise z for the given signals and return the best channel estimate."""
        best_keeper = utils.BestKeeper(hparams)
        if hparams.measurement_type == 'project':
            # Rx_hat does not depend on Tx in this mode.
            feed_dict = {Rx: Rx_val}
        else:
            feed_dict = {Tx: Tx_val, Rx: Rx_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, m_loss1_val, m_loss2_val, zp_loss_val = sess.run(
                    step_fetches, feed_dict=feed_dict)
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val, zp_loss_val))
            H_hat_val, total_loss_batch_val = sess.run(
                [H_hat, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(H_hat_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator