def sample_style(self, sess, plot_size, n_sample=10, file_id=None):
    """Generate a batch of images and save it as one square grid image.

    Args:
        sess (tf.Session): session used to evaluate the generator output.
        plot_size (int): requested side length (in images) of the saved grid.
            It is clamped so the grid never holds more cells than there are
            generated images.
        n_sample (int): currently unused. Kept for backward compatibility;
            the previous implementation overwrote it with the size of the
            generated batch before reading it.
        file_id: optional suffix; the image is saved as
            ``sample_style_<file_id>.png`` instead of ``sample_style.png``.
    """
    # NOTE(review): the generator is evaluated without feeding any
    # placeholder, so the graph presumably samples its own inputs — confirm
    # before feeding encoder/label tensors here.
    gen_im = sess.run(self._g_model.layers['generate'], feed_dict={})

    if not self._save_path:
        return

    if file_id is not None:
        file_name = 'sample_style_{}.png'.format(file_id)
    else:
        file_name = 'sample_style.png'
    im_save_path = os.path.join(self._save_path, file_name)

    n_generated = len(gen_im)
    if n_generated == 0:
        # Nothing to draw; avoid asking viz for a 0x0 grid.
        return

    # Largest square grid that fits in the generated batch, capped at plot_size.
    grid_side = int(min(plot_size, math.sqrt(n_generated)))
    viz.viz_batch_im(batch_im=gen_im,
                     grid_size=[grid_side, grid_side],
                     save_path=im_save_path,
                     gap=0, gap_color=0, shuffle=False)
def viz_generate_step(self, sess, save_path):
    """Save the stacked per-step generation results as one grid image.

    Evaluates ``self.layers['gen_step']`` (a list of image batches,
    presumably one per generation step), stacks them with ``np.vstack`` and
    writes the result to ``<save_path>/generate_step.png`` on a 10x10 grid.

    Args:
        sess (tf.Session): session used to evaluate the step outputs.
        save_path (str): directory where the image is written.
    """
    batch_step = sess.run(self.layers['gen_step'])
    step_gen_im = np.vstack(batch_step)
    # Outputs are rescaled by 255; clip so values outside the expected
    # [0, 1] range cannot become invalid pixel intensities.
    step_gen_im = np.clip(step_gen_im * 255., 0., 255.)
    viz.viz_batch_im(batch_im=step_gen_im,
                     grid_size=[10, 10],
                     save_path='{}/generate_step.png'.format(save_path),
                     is_transpose=True)
def valid_epoch(self, sess, dataflow=None, moniter_generation=False,
                summary_writer=None):
    """Evaluate the model on one full pass over the validation dataflow.

    Every validation batch is fed (with ``keep_prob=1.0``) through
    ``self._loss_op`` and ``self._valid_summary_op``. The summed loss and the
    number of steps are passed to ``display`` under the ``'valid'``
    collection. Afterwards the dataflow is re-initialised with
    ``setup(epoch_val=0, ...)`` (presumably so the next pass starts fresh).

    Args:
        sess (tf.Session): session used to run the ops.
        dataflow: validation dataflow exposing ``setup``,
            ``next_batch_dict``, ``epochs_completed`` and ``batch_size``.
            Required; the ``None`` default is kept only for signature
            compatibility.
        moniter_generation (bool): also run ``self._generate_op`` and save
            the images as ``generate_step_<global_step>.png`` under
            ``self._save_path``.
        summary_writer: optional writer. When given, the generation summary
            and the summary of the last validation batch are written at
            ``self.global_step``.

    Raises:
        ValueError: if ``dataflow`` is None.
    """
    if dataflow is None:
        raise ValueError('valid_epoch requires a validation dataflow.')

    display_name_list = ['loss']
    step = 0
    loss_sum = 0
    valid_summary = None

    dataflow.setup(epoch_val=0, batch_size=dataflow.batch_size)
    while dataflow.epochs_completed == 0:
        step += 1
        batch_data = dataflow.next_batch_dict()
        im = batch_data['im']
        loss, valid_summary = sess.run(
            [self._loss_op, self._valid_summary_op],
            feed_dict={
                self._t_model.encoder_in: im,
                self._t_model.image: im,
                self._t_model.keep_prob: 1.0,
                self._t_model.label: batch_data['label'],
            })
        loss_sum += loss

    print('[Valid]: ', end='')
    display(self.global_step,
            step,
            [loss_sum],
            display_name_list,
            'valid',
            summary_val=None,
            summary_writer=summary_writer)
    dataflow.setup(epoch_val=0, batch_size=dataflow.batch_size)

    if moniter_generation and self._save_path:
        # Only evaluate the generator when its output is actually saved.
        gen_im = sess.run(self._generate_op)
        im_save_path = os.path.join(
            self._save_path,
            'generate_step_{}.png'.format(self.global_step))
        viz.viz_batch_im(batch_im=gen_im, grid_size=[10, 10],
                         save_path=im_save_path, gap=0, gap_color=0,
                         shuffle=False)

    if summary_writer:
        cur_summary = sess.run(self._generate_summary_op)
        summary_writer.add_summary(cur_summary, self.global_step)
        # Only the summary of the last validation batch is recorded.
        if valid_summary is not None:
            summary_writer.add_summary(valid_summary, self.global_step)
def viz_generate_step(self, sess, save_path, is_animation=False, file_id=None):
    """Save the per-step generation results as a grid image or an animation.

    Evaluates ``self.layers['gen_step']`` (a list of image batches,
    presumably one batch per generation step). Pixel values are scaled by
    255 and clipped to [0, 255] before visualisation.

    Args:
        sess (tf.Session): session used to evaluate the step outputs.
        save_path (str): directory where the output is written.
        is_animation (bool): if False, stack all steps into a single grid
            image ``generate_step[_<file_id>].png`` (grid of
            ``10 x self._n_step``). If True, build one frame per step (each
            frame a square grid of that step's batch) and save them as
            ``draw_generation[_<file_id>].gif``.
        file_id: optional suffix appended to the output file name.
    """
    batch_step = sess.run(self.layers['gen_step'])
    suffix = '' if file_id is None else '_{}'.format(file_id)

    if not is_animation:
        step_gen_im = np.vstack(batch_step)
        viz.viz_batch_im(batch_im=np.clip(step_gen_im * 255., 0., 255.),
                         grid_size=[10, self._n_step],
                         save_path='{}/generate_step{}.png'.format(
                             save_path, suffix),
                         is_transpose=True)
        return

    import imageio  # only needed for the animation branch

    # Square grid side per frame; with a non-square batch size the grid has
    # room for only int(sqrt(bsize))**2 images.
    grid_side = int(batch_step[0].shape[0] ** 0.5)
    frames = []
    for batch_im in batch_step:
        merge_im = viz.viz_batch_im(
            batch_im=np.clip(batch_im * 255., 0., 255.),
            grid_size=[grid_side, grid_side],
            save_path=None,
            is_transpose=False)
        frames.append(np.squeeze(merge_im))

    imageio.mimsave('{}/draw_generation{}.gif'.format(save_path, suffix),
                    frames)
def viz_samples(self, sess, random_code, plot_size, file_id=None):
    """Decode ``random_code`` with the generator and save a square grid.

    The generator is always evaluated. The image is written only when
    ``self._save_path`` is set, as ``generate_im.png`` or
    ``generate_im_<file_id>.png``. The grid side is ``plot_size`` capped at
    ``sqrt(len(batch))``.

    Args:
        sess (tf.Session): session used to run ``self._generate_op``.
        random_code: value fed to ``self._g_model.z``.
        plot_size (int): requested grid side length in images.
        file_id: optional suffix for the output file name.
    """
    gen_im = sess.run(self._generate_op,
                      feed_dict={self._g_model.z: random_code})

    out_dir = self._save_path
    if not out_dir:
        return

    if file_id is None:
        out_name = 'generate_im.png'
    else:
        out_name = 'generate_im_{}.png'.format(file_id)

    side = int(min(plot_size, math.sqrt(len(gen_im))))
    viz.viz_batch_im(batch_im=gen_im,
                     grid_size=[side, side],
                     save_path=os.path.join(out_dir, out_name),
                     gap=0, gap_color=0, shuffle=False)
def _viz_samples(self, sess, random_vec, code_discrete, code_cont, keep_prob,
                 plot_size=10, save_path=None, file_name='generate_im',
                 file_id=None):
    """Sample images from the generator and save them as one grid image.

    Nothing is evaluated or written when ``save_path`` is falsy.

    Args:
        sess (tf.Session): session used to run ``self.generate_op``.
        random_vec: batch of input random vectors (fed to ``self.random_vec``).
        code_discrete: batch of discrete codes, presumably one-hot vectors
            (fed to ``self.code_discrete``).
        code_cont: batch of continuous codes (fed to ``self.code_continuous``).
        keep_prob (float): dropout keep probability.
        plot_size (int): grid size, converted to 2D with
            ``utils.get_shape2D``.
        save_path (str): directory for the saved image.
        file_name (str): base name of the saved PNG.
        file_id: optional index; the file becomes ``<file_name>_<file_id>.png``.
    """
    if not save_path:
        return

    grid_size = utils.get_shape2D(plot_size)
    feed = {
        self.random_vec: random_vec,
        self.keep_prob: keep_prob,
        self.code_discrete: code_discrete,
        self.code_continuous: code_cont,
    }
    gen_im = sess.run(self.generate_op, feed_dict=feed)

    if file_id is None:
        out_name = '{}.png'.format(file_name)
    else:
        out_name = '{}_{}.png'.format(file_name, file_id)

    viz.viz_batch_im(batch_im=gen_im,
                     grid_size=grid_size,
                     save_path=os.path.join(save_path, out_name),
                     gap=0, gap_color=0, shuffle=False)
def _viz_samples(self, sess, random_vec, plot_size=10,
                 file_name='generate_im', file_id=None):
    """Sample images from the generator and save them as one grid image.

    The generator is always evaluated (with ``self._keep_prob`` as the
    dropout keep probability). The image is written only when
    ``self._save_path`` is set.

    Args:
        sess (tf.Session): session used to run ``self._generate_op``.
        random_vec: batch of input random vectors
            (fed to ``self._g_model.random_vec``).
        plot_size (int): grid size, converted to 2D with
            ``utils.get_shape2D``.
        file_name (str): base name of the saved PNG.
        file_id: optional index; the file becomes ``<file_name>_<file_id>.png``.
    """
    grid_size = utils.get_shape2D(plot_size)
    gen_im = sess.run(
        self._generate_op,
        feed_dict={
            self._g_model.random_vec: random_vec,
            self._g_model.keep_prob: self._keep_prob,
        })

    out_dir = self._save_path
    if not out_dir:
        return

    stem = file_name if file_id is None else '{}_{}'.format(file_name, file_id)
    viz.viz_batch_im(batch_im=gen_im,
                     grid_size=grid_size,
                     save_path=os.path.join(out_dir, '{}.png'.format(stem)),
                     gap=0, gap_color=0, shuffle=False)
def viz_ranking(query_embedding, gallery_embedding,
                query_file_name, gallery_file_name,
                top_k=10, is_re_ranking=True,
                data_dir=None, save_path=None, is_viz=False,
                k1=20, k2=6, lambda_value=0.3):
    """Rank gallery images for every query and optionally visualise them.

    Args:
        query_embedding: embeddings of the query images.
        gallery_embedding: embeddings of the gallery images.
        query_file_name (list): paths of the query images, in the same order
            as ``query_embedding``.
        gallery_file_name (list): paths of the gallery images, in the same
            order as ``gallery_embedding``.
        top_k (int): number of gallery entries kept per query.
        is_re_ranking (bool): refine the query-gallery distances with
            ``re_ranking`` (which also uses query-query and gallery-gallery
            distances) instead of using the raw pair distances.
        data_dir (str): directory the images are read from; only the base
            name of each path is used. Required when ``is_viz`` is True.
        save_path (str): directory where ``query_<idx>.png`` strips are
            written. Required when ``is_viz`` is True.
        is_viz (bool): for each query, save a one-row strip: the query image
            (``frame_color=0``) followed by its ranked gallery images, framed
            green (RGB ``[0, 255, 0]``) when the gallery class id matches the
            query's and red (``[255, 0, 0]``) otherwise. The class id is the
            file-name prefix before the first ``_``.
        k1, k2, lambda_value: hyper-parameters passed to ``re_ranking``;
            the defaults are the previously hard-coded values.

    Returns:
        tuple: ``(query_file_name, ranking_file_mat)``, where
        ``ranking_file_mat[i]`` holds the ranked gallery files for query ``i``.

    Raises:
        ValueError: if ``is_viz`` is True but ``data_dir`` or ``save_path``
            is missing.
    """
    # Validate up front (and not via assert, which is stripped under -O) so
    # a bad call fails before the expensive distance computation.
    if is_viz and not (data_dir and save_path):
        raise ValueError(
            'data_dir and save_path are required when is_viz is True.')

    if is_re_ranking:
        q_g_dist = infertool.pair_distance(query_embedding, gallery_embedding)
        q_q_dist = infertool.pair_distance(query_embedding, query_embedding)
        g_g_dist = infertool.pair_distance(gallery_embedding, gallery_embedding)
        pair_dist = re_ranking(q_g_dist, q_q_dist, g_g_dist,
                               k1=k1, k2=k2, lambda_value=lambda_value)
    else:
        pair_dist = infertool.pair_distance(query_embedding, gallery_embedding)

    ranking_file_mat = infertool.ranking_distance(
        pair_dist, gallery_file_name, top_k=top_k)

    if is_viz:
        frame_width = 2
        frame_color_correct = [0, 255, 0]
        frame_color_wrong = [255, 0, 0]
        for idx, q_im in enumerate(query_file_name):
            q_file_name = ntpath.basename(q_im)
            q_class_id = q_file_name.split('_')[0]
            im_list = [_viz_ranking_read_framed_im(
                data_dir, q_file_name, frame_width, 0)]

            for g_im in ranking_file_mat[idx]:
                g_file_name = ntpath.basename(g_im)
                g_class_id = g_file_name.split('_')[0]
                if g_class_id == q_class_id:
                    frame_color = frame_color_correct
                else:
                    frame_color = frame_color_wrong
                im_list.append(_viz_ranking_read_framed_im(
                    data_dir, g_file_name, frame_width, frame_color))

            # Size the strip from the images actually collected, in case a
            # ranking row holds fewer than top_k entries.
            viz.viz_batch_im(im_list,
                             grid_size=[1, len(im_list)],
                             save_path=os.path.join(
                                 save_path, 'query_{}.png'.format(idx)),
                             gap=0, gap_color=0, shuffle=False)

    return query_file_name, ranking_file_mat


def _viz_ranking_read_framed_im(data_dir, file_name, frame_width, frame_color):
    """Read ``data_dir/file_name`` as an RGB image and draw a frame around it."""
    im = imageio.imread(os.path.join(data_dir, file_name),
                        as_gray=False, pilmode="RGB")
    return viz.add_frame_im(im, frame_width, frame_color=frame_color)