Exemplo n.º 1
0
    def sample_style(self, sess, plot_size, n_sample=10, file_id=None):
        """Run the generator and save the generated batch as one grid image.

        Nothing is evaluated or written when ``self._save_path`` is not set.

        Args:
            sess (tf.Session): session used to evaluate
                ``self._g_model.layers['generate']``.
            plot_size (int): requested side length (in images) of the square
                grid. It is capped at ``sqrt(number of generated images)``
                and truncated to an int.
            n_sample (int): currently unused; kept so existing callers that
                pass it keep working.
            file_id: optional suffix for the output file name. The file is
                ``sample_style_{file_id}.png``, or ``sample_style.png`` when
                ``file_id`` is None.
        """
        if not self._save_path:
            # Nothing would be saved, so skip the generator forward pass.
            return

        gen_im = sess.run(self._g_model.layers['generate'])

        if file_id is not None:
            file_name = 'sample_style_{}.png'.format(file_id)
        else:
            file_name = 'sample_style.png'
        im_save_path = os.path.join(self._save_path, file_name)

        # Size the grid from what was actually generated. Use a separate local
        # instead of clobbering the ``n_sample`` argument.
        n_generated = len(gen_im)
        plot_size = int(min(plot_size, math.sqrt(n_generated)))
        viz.viz_batch_im(batch_im=gen_im, grid_size=[plot_size, plot_size],
                         save_path=im_save_path, gap=0, gap_color=0,
                         shuffle=False)
    def viz_generate_step(self, sess, save_path):
        """Save the intermediate generation steps as one grid image.

        Args:
            sess (tf.Session): session used to evaluate
                ``self.layers['gen_step']``.
            save_path (str): directory in which ``generate_step.png`` is
                written.
        """
        batch_step = sess.run(self.layers['gen_step'])
        # Concatenate the per-step outputs along the first axis into one batch.
        step_gen_im = np.vstack(batch_step)

        # NOTE(review): scaling by 255 assumes generator outputs in [0, 1].
        # Clip so any out-of-range values saturate instead of producing
        # invalid 8-bit pixel values (matches the later revision of this method).
        viz.viz_batch_im(batch_im=np.clip(step_gen_im * 255., 0., 255.),
                         grid_size=[10, 10],
                         save_path=os.path.join(save_path, 'generate_step.png'),
                         is_transpose=True)
Exemplo n.º 3
0
    def valid_epoch(self,
                    sess,
                    dataflow=None,
                    moniter_generation=False,
                    summary_writer=None):
        """Run one full pass over the validation data and report the loss.

        Args:
            sess (tf.Session): session used to evaluate the graph.
            dataflow: validation data source exposing ``setup``,
                ``next_batch_dict``, ``epochs_completed`` and ``batch_size``.
                Required, despite the ``None`` default kept for
                backward compatibility.
            moniter_generation (bool): when True and ``self._save_path`` is
                set, also save a 10x10 grid of generated images named
                ``generate_step_{global_step}.png``.
            summary_writer: optional writer with an ``add_summary`` method
                (e.g. a TF FileWriter). It receives the generation summary
                and the summary from the last validation batch.

        Raises:
            ValueError: if ``dataflow`` is None.
        """
        if dataflow is None:
            raise ValueError('valid_epoch requires a dataflow.')

        dataflow.setup(epoch_val=0, batch_size=dataflow.batch_size)
        display_name_list = ['loss']

        step = 0
        loss_sum = 0
        # Stays None if no batch is processed, so the summary writing below
        # never touches an unbound name.
        valid_summary = None
        while dataflow.epochs_completed == 0:
            step += 1

            batch_data = dataflow.next_batch_dict()
            im = batch_data['im']
            label = batch_data['label']
            loss, valid_summary = sess.run(
                [self._loss_op, self._valid_summary_op],
                feed_dict={
                    self._t_model.encoder_in: im,
                    self._t_model.image: im,
                    self._t_model.keep_prob: 1.0,
                    self._t_model.label: label,
                })
            loss_sum += loss

        print('[Valid]: ', end='')
        # NOTE(review): passes the summed loss plus the step count;
        # presumably `display` averages over steps — confirm.
        display(self.global_step,
                step, [loss_sum],
                display_name_list,
                'valid',
                summary_val=None,
                summary_writer=summary_writer)
        # Call setup again with epoch_val=0, presumably so the next
        # validation pass starts from a fresh epoch.
        dataflow.setup(epoch_val=0, batch_size=dataflow.batch_size)

        if moniter_generation and self._save_path:
            # Generated images are only needed when they are saved.
            gen_im = sess.run(self._generate_op)
            im_save_path = os.path.join(
                self._save_path,
                'generate_step_{}.png'.format(self.global_step))
            viz.viz_batch_im(batch_im=gen_im,
                             grid_size=[10, 10],
                             save_path=im_save_path,
                             gap=0,
                             gap_color=0,
                             shuffle=False)
        if summary_writer:
            cur_summary = sess.run(self._generate_summary_op)
            summary_writer.add_summary(cur_summary, self.global_step)
            if valid_summary is not None:
                summary_writer.add_summary(valid_summary, self.global_step)
    def viz_generate_step(self,
                          sess,
                          save_path,
                          is_animation=False,
                          file_id=None):
        """Visualize the intermediate outputs of the step-wise generator.

        Args:
            sess (tf.Session): session used to evaluate
                ``self.layers['gen_step']``.
            save_path (str): output directory.
            is_animation (bool): if False, save all steps as one grid image
                (``generate_step.png`` or ``generate_step_{file_id}.png``).
                If True, render one square grid per step and save the frames
                as a GIF (``draw_generation.gif`` or
                ``draw_generation_{file_id}.gif``).
            file_id: optional suffix appended to the output file name.
        """
        batch_step = sess.run(self.layers['gen_step'])
        suffix = '' if file_id is None else '_{}'.format(file_id)

        if not is_animation:
            step_gen_im = np.vstack(batch_step)
            save_name = os.path.join(
                save_path, 'generate_step{}.png'.format(suffix))
            viz.viz_batch_im(batch_im=np.clip(step_gen_im * 255., 0., 255.),
                             grid_size=[10, self._n_step],
                             save_path=save_name,
                             is_transpose=True)
            return

        # Local import: imageio is only needed for the animation path.
        import imageio

        # Side of the square grid per frame: floor(sqrt(batch size)).
        bsize = batch_step[0].shape[0]
        grid_size = int(bsize**0.5)
        image_list = []
        for batch_im in batch_step:
            merge_im = viz.viz_batch_im(
                batch_im=np.clip(batch_im * 255., 0., 255.),
                grid_size=[grid_size, grid_size],
                save_path=None,
                is_transpose=False)
            # Give imageio 8-bit frames. For float frames outside [0, 1] it
            # min/max-rescales every frame independently, which makes the
            # brightness flicker across the animation.
            frame = np.clip(np.squeeze(merge_im), 0., 255.)
            image_list.append(np.round(frame).astype(np.uint8))

        save_name = os.path.join(
            save_path, 'draw_generation{}.gif'.format(suffix))
        imageio.mimsave(save_name, image_list)
Exemplo n.º 5
0
    def viz_samples(self, sess, random_code, plot_size, file_id=None):
        """Generate images from ``random_code`` and save them as one grid.

        The generator is always evaluated. The grid image is written only
        when ``self._save_path`` is set.

        Args:
            sess (tf.Session): session used to evaluate ``self._generate_op``.
            random_code: value fed to ``self._g_model.z``.
            plot_size (int): requested grid side length, capped at
                ``sqrt(number of generated images)`` and truncated to an int.
            file_id: optional suffix. Output is ``generate_im_{file_id}.png``,
                or ``generate_im.png`` when ``file_id`` is None.
        """
        generated = sess.run(self._generate_op,
                             feed_dict={self._g_model.z: random_code})
        if not self._save_path:
            return

        if file_id is None:
            name = 'generate_im.png'
        else:
            name = 'generate_im_{}.png'.format(file_id)
        out_path = os.path.join(self._save_path, name)

        side = int(min(plot_size, math.sqrt(len(generated))))
        viz.viz_batch_im(batch_im=generated,
                         grid_size=[side, side],
                         save_path=out_path,
                         gap=0,
                         gap_color=0,
                         shuffle=False)
Exemplo n.º 6
0
    def _viz_samples(self,
                     sess,
                     random_vec,
                     code_discrete,
                     code_cont,
                     keep_prob,
                     plot_size=10,
                     save_path=None,
                     file_name='generate_im',
                     file_id=None):
        """Generate a batch of images and save it as one grid image.

        Does nothing at all (no forward pass) when ``save_path`` is falsy.

        Args:
            sess (tf.Session): session used to evaluate ``self.generate_op``.
            random_vec: random input vectors fed to ``self.random_vec``.
            code_discrete: discrete (one-hot) codes fed to
                ``self.code_discrete``.
            code_cont: continuous codes fed to ``self.code_continuous``.
            keep_prob (float): dropout keep probability fed to
                ``self.keep_prob``.
            plot_size: grid size passed through ``utils.get_shape2D``
                (presumably an int or a 2-element size).
            save_path (str): output directory.
            file_name (str): base name of the output file.
            file_id: optional suffix. Output is ``{file_name}_{file_id}.png``,
                or ``{file_name}.png`` when ``file_id`` is None.
        """
        if not save_path:
            return

        grid = utils.get_shape2D(plot_size)
        feed = {
            self.random_vec: random_vec,
            self.keep_prob: keep_prob,
            self.code_discrete: code_discrete,
            self.code_continuous: code_cont,
        }
        gen_im = sess.run(self.generate_op, feed_dict=feed)

        if file_id is None:
            base_name = '{}.png'.format(file_name)
        else:
            base_name = '{}_{}.png'.format(file_name, file_id)

        viz.viz_batch_im(batch_im=gen_im,
                         grid_size=grid,
                         save_path=os.path.join(save_path, base_name),
                         gap=0,
                         gap_color=0,
                         shuffle=False)
Exemplo n.º 7
0
    def _viz_samples(self,
                     sess,
                     random_vec,
                     plot_size=10,
                     file_name='generate_im',
                     file_id=None):
        """Sample images from the generator and save them as one grid image.

        Nothing is evaluated or written when ``self._save_path`` is not set.

        Args:
            sess (tf.Session): session used to evaluate ``self._generate_op``.
            random_vec: random input vectors fed to
                ``self._g_model.random_vec``.
            plot_size: grid size passed through ``utils.get_shape2D``
                (presumably an int or a 2-element size).
            file_name (str): base name of the output file.
            file_id: optional suffix. Output is ``{file_name}_{file_id}.png``,
                or ``{file_name}.png`` when ``file_id`` is None.
        """
        if not self._save_path:
            # Nothing would be written, so skip the generator forward pass.
            return

        plot_size = utils.get_shape2D(plot_size)
        gen_im = sess.run(self._generate_op,
                          feed_dict={
                              self._g_model.random_vec: random_vec,
                              self._g_model.keep_prob: self._keep_prob
                          })

        if file_id is not None:
            im_name = '{}_{}.png'.format(file_name, file_id)
        else:
            im_name = '{}.png'.format(file_name)
        im_save_path = os.path.join(self._save_path, im_name)

        viz.viz_batch_im(batch_im=gen_im,
                         grid_size=plot_size,
                         save_path=im_save_path,
                         gap=0,
                         gap_color=0,
                         shuffle=False)
Exemplo n.º 8
0
def viz_ranking(query_embedding,
                gallery_embedding,
                query_file_name,
                gallery_file_name,
                top_k=10,
                is_re_ranking=True,
                data_dir=None,
                save_path=None,
                is_viz=False):
    """Rank gallery entries for each query and optionally save a visualization.

    Args:
        query_embedding: query embeddings, passed to ``infertool.pair_distance``.
        gallery_embedding: gallery embeddings.
        query_file_name (sequence of str): image path of each query, presumably
            in the same order as ``query_embedding``.
        gallery_file_name (sequence of str): image path of each gallery entry,
            passed to ``infertool.ranking_distance``.
        top_k (int): number of ranked gallery entries requested per query.
        is_re_ranking (bool): if True, refine the query-gallery distances with
            ``re_ranking`` (k1=20, k2=6, lambda_value=0.3). The query-query
            and gallery-gallery distances are used as additional inputs.
        data_dir (str): directory the images are read from. Only the base name
            of each path is used. Required when ``is_viz`` is True.
        save_path (str): directory for the ``query_{idx}.png`` images.
            Required when ``is_viz`` is True.
        is_viz (bool): if True, save one row per query. The row holds the
            query image (frame_color=0), then its ranked gallery images.
            Gallery frames are [0, 255, 0] when the class id (file-name prefix
            before the first '_') matches the query's, and [255, 0, 0]
            otherwise.

    Returns:
        tuple: ``(query_file_name, ranking_file_mat)``. Row ``i`` of
        ``ranking_file_mat`` holds the ranked gallery file paths for query ``i``.

    Raises:
        ValueError: if ``is_viz`` is True but ``data_dir`` or ``save_path`` is
            missing.
    """
    # Validate up front (not with `assert`, which `python -O` strips) so a bad
    # call fails before the expensive distance computations.
    if is_viz and not (data_dir and save_path):
        raise ValueError(
            'viz_ranking: data_dir and save_path are required when is_viz=True')

    q_g_dist = infertool.pair_distance(query_embedding, gallery_embedding)
    if is_re_ranking:
        q_q_dist = infertool.pair_distance(query_embedding, query_embedding)
        g_g_dist = infertool.pair_distance(gallery_embedding,
                                           gallery_embedding)
        pair_dist = re_ranking(q_g_dist,
                               q_q_dist,
                               g_g_dist,
                               k1=20,
                               k2=6,
                               lambda_value=0.3)
    else:
        pair_dist = q_g_dist

    ranking_file_mat = infertool.ranking_distance(pair_dist,
                                                  gallery_file_name,
                                                  top_k=top_k)

    if not is_viz:
        return query_file_name, ranking_file_mat

    frame_width = 2
    frame_color_correct = [0, 255, 0]
    frame_color_wrong = [255, 0, 0]

    def _class_id(file_name):
        # Class id is the file-name prefix before the first underscore.
        return file_name.split('_')[0]

    def _read_rgb(file_name):
        # Images are looked up in data_dir by base name only.
        return imageio.imread(os.path.join(data_dir, file_name),
                              as_gray=False,
                              pilmode="RGB")

    for idx, q_im in enumerate(query_file_name):
        q_file_name = ntpath.basename(q_im)
        q_class_id = _class_id(q_file_name)
        im_list = [
            viz.add_frame_im(_read_rgb(q_file_name), frame_width,
                             frame_color=0)
        ]

        for g_im in ranking_file_mat[idx]:
            g_file_name = ntpath.basename(g_im)
            if _class_id(g_file_name) == q_class_id:
                frame_color = frame_color_correct
            else:
                frame_color = frame_color_wrong
            im_list.append(
                viz.add_frame_im(_read_rgb(g_file_name), frame_width,
                                 frame_color=frame_color))

        # Size the row from the images actually collected, so it stays
        # consistent even if a ranking row has fewer than top_k entries.
        viz.viz_batch_im(im_list,
                         grid_size=[1, len(im_list)],
                         save_path=os.path.join(
                             save_path, 'query_{}.png'.format(idx)),
                         gap=0,
                         gap_color=0,
                         shuffle=False)

    return query_file_name, ranking_file_mat