コード例 #1
0
ファイル: GazeGAN.py プロジェクト: bensewell/GazeCorrection
    def test(self):
        """Restore the latest checkpoint and dump test-set samples to disk.

        Reads test batches from the 4th/5th tensors returned by
        ``self.dataset.input()``, derives masks and eye positions with
        ``self.get_Mask_and_pos`` and, for every batch, saves grids of
        ``[self.x, self.y]`` (arranged by ``self.Transpose``) to
        ``self.opt.test_sample_dir`` as ``<batch>_<i>.jpg``.

        Raises:
            SystemExit: with status 1 when no checkpoint exists in
                ``self.opt.checkpoints_dir``.
        """
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            print('Load checkpoint')
            if ckpt and ckpt.model_checkpoint_path:
                self.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Load Succeed!')
            else:
                print('Do not exists any checkpoint,Load Failed!')
                # `exit()` is a `site` convenience that is not guaranteed to
                # exist; a non-zero status also tells callers the run failed.
                raise SystemExit(1)

            _, _, _, testbatch, testmask = self.dataset.input()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                # Floor division: under Python 3 `/` yields a float, which
                # `range` rejects. 1000 is presumably the test-set size --
                # TODO confirm.
                batch_num = 1000 // self.opt.batch_size
                for j in range(batch_num):
                    real_test_batch, real_eye_pos = sess.run(
                        [testbatch, testmask])
                    batch_masks, batch_left_eye_pos, batch_right_eye_pos = self.get_Mask_and_pos(
                        real_eye_pos)
                    f_d = {
                        self.x: real_test_batch,
                        self.xm: batch_masks,
                        self.x_left_p: batch_left_eye_pos,
                        self.x_right_p: batch_right_eye_pos
                    }

                    for i in range(self.opt.batch_size):
                        # NOTE(review): every iteration runs the same fetches
                        # on the same feed; one file per sample may have been
                        # intended -- confirm.
                        output = sess.run([self.x, self.y], feed_dict=f_d)
                        output_concat = self.Transpose(
                            np.array([output[0], output[1]]))
                        # save_images(images, path); the indices are
                        # zero-padded so no spaces end up in the filename.
                        save_images(
                            output_concat, '{}/{:02d}_{:02d}.jpg'.format(
                                self.opt.test_sample_dir, j, i))
            finally:
                # Stop the input-queue threads even if the loop raised.
                coord.request_stop()
                coord.join(threads)
コード例 #2
0
    def test(self):
        """Run the model on 2000 test batches and save one sample grid per batch.

        A new ``tf.train.Saver`` is stored on ``self.saver`` and used to
        restore the latest checkpoint from ``self.opt.checkpoints_dir``. When
        none is found, a message is printed and the freshly initialised
        weights are used. For every batch, ``self._x`` is split along its last
        axis into ``2 * self.opt.n_att`` pieces, the input ``self.x`` is put
        in front, and the list arranged by ``self.Transpose`` is written to
        ``<self.opt.test_sample_dir>/<i>_te.jpg``.
        """
        self.saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            if ckpt and ckpt.model_checkpoint_path:
                self.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # NOTE(review): testing continues with untrained weights here.
                print("Do not found the pretrained model")

            print("Start read dataset")
            _, _, te_img, te_label = self.dataset.input()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print("Start entering the looping")
                for i in range(2000):
                    print(i)

                    _te_img, _te_label = sess.run([te_img, te_label])
                    f_d = {self.x: _te_img, self.label: _te_label}
                    te_o = sess.run([self.x, self._x], feed_dict=f_d)
                    te_x_list = np.split(te_o[1],
                                         indices_or_sections=self.opt.n_att * 2,
                                         axis=-1)
                    te_x_list.insert(0, te_o[0])
                    _te_o = self.Transpose(te_x_list)
                    save_images(
                        _te_o, '{}/{:02d}_te.jpg'.format(
                            self.opt.test_sample_dir, i))
            finally:
                # Stop the input-queue threads even if the loop raised.
                coord.request_stop()
                coord.join(threads)
コード例 #3
0
ファイル: GazeGAN.py プロジェクト: bensewell/GazeCorrection
    def train(self):
        """Train the discriminator and generator, resuming from a checkpoint.

        Trainable variables are split by name into ``self.d_vars`` ('D'),
        ``self.g_vars`` ('G') and ``self.e_vars`` ('encode'). The assertion
        only checks that the three lists together have as many entries as
        there are trainable variables. Each iteration runs one Adam step for
        D and one for G. The value fed to ``self.lr_decay`` is 1 up to step
        20000 and is then recomputed every 2000 steps as
        ``(niter - step) / (niter - 20000)``.

        Every 500 steps the losses are printed. Every 2000 steps, train/test
        sample grids are written to ``self.opt.sample_dir``. Every 20000
        steps, and once more after the loop, a ``model_<step>.ckpt``
        checkpoint is written to ``self.opt.checkpoints_dir``.
        """
        self.t_vars = tf.trainable_variables()
        self.d_vars = [var for var in self.t_vars if 'D' in var.name]
        self.g_vars = [var for var in self.t_vars if 'G' in var.name]
        self.e_vars = [var for var in self.t_vars if 'encode' in var.name]
        assert len(self.t_vars) == len(self.d_vars + self.g_vars + self.e_vars)

        self.saver = tf.train.Saver()
        # Saver restricted to the 'encode' variables; not used in this method.
        self.p_saver = tf.train.Saver(self.e_vars)

        opti_D = tf.train.AdamOptimizer(
            self.opt.lr_d * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.D_loss,
                                           var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(
            self.opt.lr_g * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.G_loss,
                                           var_list=self.g_vars)

        # Build the logged loss tensors once. Multiplying by the loss weights
        # inside the loop would add new ops to the graph on every logging
        # step, so the graph (and its memory) would keep growing.
        loss_fetches = [
            self.D_loss, self.G_loss, self.opt.lam_r * self.recon_loss,
            self.opt.lam_p * self.percep_loss
        ]
        if self.opt.is_ss:
            loss_fetches += [self.r_cls_loss, self.f_cls_loss]

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            start_step = 0
            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoints are named 'model_<step>.ckpt'; resume at <step>.
                start_step = int(
                    ckpt.model_checkpoint_path.split('model_',
                                                     2)[1].split('.', 2)[0])
                self.saver.restore(sess, ckpt.model_checkpoint_path)
            # Otherwise training starts from the freshly initialised weights.

            step = start_step
            lr_decay = 1

            print("Start read dataset")

            image_path, train_images, train_eye_pos, test_images, test_eye_pos = self.dataset.input(
            )
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print("Start entering the looping")
                # One fixed test batch is reused for every sample dump.
                real_test_batch, real_test_pos = sess.run(
                    [test_images, test_eye_pos])

                while step <= self.opt.niter:

                    if step > 20000 and step % 2000 == 0:
                        lr_decay = (self.opt.niter - step) / float(
                            self.opt.niter - 20000)

                    # The image paths are fetched with the batch but unused.
                    _, x_data, x_p_data = sess.run(
                        [image_path, train_images, train_eye_pos])
                    xm_data, x_left_p_data, x_right_p_data = self.get_Mask_and_pos(
                        x_p_data)

                    f_d = {
                        self.x: x_data,
                        self.xm: xm_data,
                        self.x_left_p: x_left_p_data,
                        self.x_right_p: x_right_p_data,
                        self.lr_decay: lr_decay
                    }

                    # optimize D
                    sess.run(opti_D, feed_dict=f_d)
                    # optimize G
                    sess.run(opti_G, feed_dict=f_d)

                    if step % 500 == 0:
                        output_loss = sess.run(loss_fetches, feed_dict=f_d)
                        if self.opt.is_ss:
                            print(
                                "step %d D_loss=%.8f, G_loss=%.4f, Recon_loss=%.4f, Percep_loss=%.4f, "
                                "Real_class_loss=%.4f, Fake_class_loss=%.4f, lr_decay=%.4f"
                                % (step, output_loss[0], output_loss[1],
                                   output_loss[2], output_loss[3],
                                   output_loss[4], output_loss[5], lr_decay))
                        else:
                            print(
                                "step %d D_loss=%.8f, G_loss=%.4f, Recon_loss=%.4f, Percep_loss=%.4f, lr_decay=%.4f"
                                % (step, output_loss[0], output_loss[1],
                                   output_loss[2], output_loss[3], lr_decay))

                    if np.mod(step, 2000) == 0:

                        train_output_img = sess.run([
                            self.xl_left, self.xl_right, self.xc, self.yo,
                            self.y, self.yl_left, self.yl_right
                        ],
                                                    feed_dict=f_d)

                        test_masks, test_left_pos, test_right_pos = self.get_Mask_and_pos(
                            real_test_pos)
                        test_f_d = {
                            self.x: real_test_batch,
                            self.xm: test_masks,
                            self.x_left_p: test_left_pos,
                            self.x_right_p: test_right_pos,
                            self.lr_decay: lr_decay
                        }

                        test_output_img = sess.run(
                            [self.xc, self.yo, self.y], feed_dict=test_f_d)
                        output_concat = self.Transpose(
                            np.array([
                                x_data, train_output_img[2],
                                train_output_img[3], train_output_img[4]
                            ]))
                        local_output_concat = self.Transpose(
                            np.array([
                                train_output_img[0], train_output_img[1],
                                train_output_img[5], train_output_img[6]
                            ]))
                        test_output_concat = self.Transpose(
                            np.array([
                                real_test_batch, test_output_img[0],
                                test_output_img[2], test_output_img[1]
                            ]))
                        save_images(
                            local_output_concat,
                            '{}/{:02d}_local_output.jpg'.format(
                                self.opt.sample_dir, step))
                        save_images(
                            output_concat,
                            '{}/{:02d}_output.jpg'.format(
                                self.opt.sample_dir, step))
                        save_images(
                            test_output_concat,
                            '{}/{:02d}_test_output.jpg'.format(
                                self.opt.sample_dir, step))

                    if np.mod(step, 20000) == 0:
                        self.saver.save(
                            sess,
                            os.path.join(self.opt.checkpoints_dir,
                                         'model_{:06d}.ckpt'.format(step)))

                    step += 1

                save_path = self.saver.save(
                    sess,
                    os.path.join(self.opt.checkpoints_dir,
                                 'model_{:06d}.ckpt'.format(step)))
            finally:
                # Stop the input-queue threads even if training raised.
                coord.request_stop()
                coord.join(threads)

            print("Model saved in file: %s" % save_path)
コード例 #4
0
    def train(self):
        """Train the attribute-editing GAN, resuming from a checkpoint if any.

        Trainable variables whose names contain 'D' go to ``self.d_vars`` and
        those containing 'G' go to ``self.g_vars``. The assertion checks that
        the two lists together have as many entries as there are trainable
        variables. D gets an Adam step every iteration and G every
        ``self.opt.n_critic`` iterations. The value fed to ``self.lr_decay``
        is 1 up to step 20000 and afterwards ``(niter - s) / (niter - 20000)``,
        where ``s`` is the most recent multiple of 2000.

        Every 500 steps the losses are printed. Every 2000 steps, train/test
        sample grids are written to ``self.opt.sample_dir``. Every
        ``self.opt.save_model_freq`` steps, and once more after the loop, a
        ``model_<step>.ckpt`` checkpoint is written to
        ``self.opt.checkpoints_dir``.
        """
        self.t_vars = tf.trainable_variables()
        self.d_vars = [var for var in self.t_vars if 'D' in var.name]
        self.g_vars = [var for var in self.t_vars if 'G' in var.name]

        assert len(self.t_vars) == len(self.d_vars + self.g_vars)

        self.saver = tf.train.Saver()
        opti_D = tf.train.AdamOptimizer(
            self.opt.lr_d * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.D_loss,
                                           var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(
            self.opt.lr_g * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.G_loss,
                                           var_list=self.g_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoints are named 'model_<step>.ckpt'; resume at <step>.
                start_step = int(
                    ckpt.model_checkpoint_path.split('model_',
                                                     2)[1].split('.', 2)[0])
                self.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                start_step = 0

            step = start_step
            # Rebuild the decay an uninterrupted run would be using at `step`.
            # The loop only refreshes it on multiples of 2000, so resuming from
            # a checkpoint saved at another step (save_model_freq is
            # configurable) would otherwise train at the undecayed rate until
            # the next refresh. The upper bound keeps the denominator positive;
            # if step > niter the loop below never runs anyway.
            last_update = step - step % 2000
            if 20000 < last_update <= self.opt.niter:
                lr_decay = (self.opt.niter - last_update) / float(
                    self.opt.niter - 20000)
            else:
                lr_decay = 1
            print("Start read dataset")

            tr_img, tr_label, te_img, te_label = self.dataset.input()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                # One fixed test batch is reused for every sample dump.
                _te_img, _te_label = sess.run([te_img, te_label])
                print("Start entering the looping")
                while step <= self.opt.niter:

                    if step > 20000 and step % 2000 == 0:
                        lr_decay = (self.opt.niter - step) / float(
                            self.opt.niter - 20000)

                    _tr_img, _tr_label = sess.run([tr_img, tr_label])
                    f_d = {
                        self.x: _tr_img,
                        self.label: _tr_label,
                        self.lr_decay: lr_decay
                    }
                    # optimize D
                    sess.run(opti_D, feed_dict=f_d)
                    # optimize G
                    if step % self.opt.n_critic == 0:
                        sess.run(opti_G, feed_dict=f_d)

                    if step % 500 == 0:

                        o_loss = sess.run([self.D_loss, self.G_loss],
                                          feed_dict=f_d)
                        print("step %d D_loss=%.8f, G_loss=%.4f lr_decay=%.4f" %
                              (step, o_loss[0], o_loss[1], lr_decay))

                    if np.mod(step, 2000) == 0:

                        tr_o = sess.run([self.x, self._x], feed_dict=f_d)

                        te_f_d = {self.x: _te_img, self.label: _te_label}
                        te_o = sess.run([self.x, self._x], feed_dict=te_f_d)

                        # self._x holds 2 * n_att outputs along the last axis.
                        n_sections = self.opt.n_att * 2
                        tr_x_list = np.split(tr_o[1],
                                             indices_or_sections=n_sections,
                                             axis=-1)
                        te_x_list = np.split(te_o[1],
                                             indices_or_sections=n_sections,
                                             axis=-1)

                        tr_x_list.insert(0, tr_o[0])
                        te_x_list.insert(0, te_o[0])

                        _tr_o = self.Transpose(tr_x_list)
                        _te_o = self.Transpose(te_x_list)
                        save_images(
                            _tr_o,
                            '{}/{:02d}_tr.jpg'.format(self.opt.sample_dir,
                                                      step))
                        save_images(
                            _te_o,
                            '{}/{:02d}_te.jpg'.format(self.opt.sample_dir,
                                                      step))

                    if np.mod(step, self.opt.save_model_freq) == 0:
                        self.saver.save(
                            sess,
                            os.path.join(self.opt.checkpoints_dir,
                                         'model_{:06d}.ckpt'.format(step)))

                    step += 1

                save_path = self.saver.save(
                    sess,
                    os.path.join(self.opt.checkpoints_dir,
                                 'model_{:06d}.ckpt'.format(step)))
            finally:
                # Stop the input-queue threads even if training raised.
                coord.request_stop()
                coord.join(threads)

            print("Model saved in file: %s" % save_path)
コード例 #5
0
    def train(self):
        """Train the gaze-inpainting GAN, resuming or warm-starting if possible.

        The D optimiser updates ``self.d_vars`` and the G optimiser updates
        ``self.ed_vars``, each with learning rate ``base_lr * self.lr_decay``.
        Training resumes from the latest checkpoint in ``self.model_dir``.
        Otherwise it tries to restore ``self.pretrain_saver`` from
        ``<pretrain_model_dir>/model_<pretrain_model_index>.ckpt``; if that
        fails, it prints a warning and starts from scratch.

        The fed decay starts at ``self.lr_init`` and, after step 20000, is
        recomputed every 2000 steps as
        ``(max_iters - step) / (max_iters - 20000)``. Summaries are written
        to ``self.log_dir`` every step and losses are printed every 500 steps.
        Sample grids are saved to ``self.sample_dir`` every 2000 steps. Every
        20000 steps the Inception/FID scores are computed and a checkpoint is
        saved.
        """
        opti_D = tf.train.AdamOptimizer(
            self.d_learning_rate * self.lr_decay,
            beta1=self.beta1,
            beta2=self.beta2).minimize(loss=self.D_loss,
                                       var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(
            self.g_learning_rate * self.lr_decay,
            beta1=self.beta1,
            beta2=self.beta2).minimize(loss=self.G_loss,
                                       var_list=self.ed_vars)

        # Build the logged loss tensors once. Multiplying by the loss weights
        # inside the loop would add new ops to the graph on every logging
        # step, so the graph (and its memory) would keep growing.
        loss_fetches = [
            self.D_loss, self.G_loss, self.lam_recon * self.recon_loss,
            self.lam_percep * self.percep_loss
        ]
        if self.is_supervised:
            loss_fetches += [self.real_class_loss, self.fake_class_loss]

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)

            start_step = 0
            ckpt = tf.train.get_checkpoint_state(self.model_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoints are named 'model_<step>.ckpt'; resume at <step>.
                start_step = int(
                    ckpt.model_checkpoint_path.split('model_',
                                                     2)[1].split('.', 2)[0])
                self.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                try:
                    self.pretrain_saver.restore(
                        sess,
                        os.path.join(
                            self.pretrain_model_dir,
                            'model_{:06d}.ckpt'.format(
                                self.pretrain_model_index)))
                except Exception:
                    # Best effort: warn and fall back to the initialised
                    # weights. `Exception` (unlike a bare except) still lets
                    # KeyboardInterrupt/SystemExit through.
                    print(" Self-Guided Model path may not be correct")

            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            step = start_step
            lr_decay = self.lr_init

            print("Start read dataset")

            image_path, train_images, train_eye_pos, test_images, test_eye_pos = self.dataset.input(
            )
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print("Start entering the looping")
                # One fixed test batch is reused for every sample dump.
                real_test_batch, real_test_pos = sess.run(
                    [test_images, test_eye_pos])

                while step <= self.max_iters:

                    if step > 20000 and step % 2000 == 0:
                        lr_decay = (self.max_iters - step) / float(
                            self.max_iters - 20000)

                    # The image paths are fetched with the batch but unused.
                    _, real_batch_image, real_eye_pos = sess.run(
                        [image_path, train_images, train_eye_pos])
                    batch_masks, batch_left_eye_pos, batch_right_eye_pos = self.get_Mask_and_pos(
                        real_eye_pos)

                    f_d = {
                        self.input: real_batch_image,
                        self.mask: batch_masks,
                        self.input_left_labels: batch_left_eye_pos,
                        self.input_right_labels: batch_right_eye_pos,
                        self.lr_decay: lr_decay
                    }

                    # optimize D
                    sess.run(opti_D, feed_dict=f_d)
                    # optimize G
                    sess.run(opti_G, feed_dict=f_d)

                    summary_str = sess.run(summary_op, feed_dict=f_d)
                    summary_writer.add_summary(summary_str, step)

                    if step % 500 == 0:
                        output_loss = sess.run(loss_fetches, feed_dict=f_d)
                        if self.is_supervised:
                            print(
                                "step %d D_loss=%.8f, G_loss=%.4f, Recon_loss=%.4f, Percep_loss=%.4f, "
                                "Real_class_loss=%.4f, Fake_class_loss=%.4f, lr_decay=%.4f"
                                % (step, output_loss[0], output_loss[1],
                                   output_loss[2], output_loss[3],
                                   output_loss[4], output_loss[5], lr_decay))
                        else:
                            print(
                                "step %d D_loss=%.8f, G_loss=%.4f, Recon_loss=%.4f, Percep_loss=%.4f, lr_decay=%.4f"
                                % (step, output_loss[0], output_loss[1],
                                   output_loss[2], output_loss[3], lr_decay))

                    if np.mod(step, 2000) == 0:

                        train_output_img = sess.run([
                            self.local_input_left, self.local_input_right,
                            self.incomplete_img, self.recon_img,
                            self.new_recon_img, self.local_recon_img_left,
                            self.local_recon_img_right
                        ],
                                                    feed_dict=f_d)

                        test_masks, test_left_pos, test_right_pos = self.get_Mask_and_pos(
                            real_test_pos)
                        test_f_d = {
                            self.input: real_test_batch,
                            self.mask: test_masks,
                            self.input_left_labels: test_left_pos,
                            self.input_right_labels: test_right_pos,
                            self.lr_decay: lr_decay
                        }

                        test_output_img = sess.run([
                            self.incomplete_img, self.recon_img,
                            self.new_recon_img
                        ],
                                                   feed_dict=test_f_d)
                        output_concat = np.concatenate([
                            real_batch_image, train_output_img[2],
                            train_output_img[3], train_output_img[4]
                        ],
                                                       axis=0)
                        local_output_concat = np.concatenate([
                            train_output_img[0], train_output_img[1],
                            train_output_img[5], train_output_img[6]
                        ],
                                                             axis=0)
                        test_output_concat = np.concatenate([
                            real_test_batch, test_output_img[0],
                            test_output_img[2], test_output_img[1]
                        ],
                                                            axis=0)

                        # Grid rows must be integers: `/` would yield a float
                        # under Python 3, so use floor division.
                        save_images(
                            local_output_concat, [
                                local_output_concat.shape[0] //
                                self.batch_size, self.batch_size
                            ], '{}/{:02d}_local_output.jpg'.format(
                                self.sample_dir, step))
                        save_images(
                            output_concat, [
                                output_concat.shape[0] // self.batch_size,
                                self.batch_size
                            ],
                            '{}/{:02d}_output.jpg'.format(self.sample_dir,
                                                          step))
                        save_images(
                            test_output_concat, [
                                test_output_concat.shape[0] //
                                self.batch_size, self.batch_size
                            ], '{}/{:02d}_test_output.jpg'.format(
                                self.sample_dir, step))

                    if np.mod(step, 20000) == 0:
                        self.Inception_score(sess, test_images, test_eye_pos,
                                             step)
                        self.FID_score(sess, test_images, test_eye_pos, step)
                        self.saver.save(
                            sess,
                            os.path.join(self.model_dir,
                                         'model_{:06d}.ckpt'.format(step)))

                    step += 1

                save_path = self.saver.save(
                    sess,
                    os.path.join(self.model_dir,
                                 'model_{:06d}.ckpt'.format(step)))
            finally:
                # Flush summaries and stop the input-queue threads even if
                # training raised.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)

            print("Model saved in file: %s" % save_path)
コード例 #6
0
    def train(self):
        """Train the two-domain (x/y) GAN, resuming or warm-starting if possible.

        Trainable variables are grouped by name substring: 'Dx', 'Dy', 'Gx',
        'Gy', 'encode' and 'Gr'. Four Adam optimisers update Dx, Dy, Gx and
        Gy (Gy together with the 'encode' variables) once per iteration.
        Variables selected by ``slim.get_variables_to_restore(include=['vgg_16'])``
        are always restored from ``self.opt.vgg_path``. Training then resumes
        from the latest checkpoint in ``self.opt.checkpoints_dir``; otherwise
        the 'Gr' variables are restored from ``self.opt.pretrain_path`` when
        possible.

        The fed decay is 1 up to step 20000 and is then recomputed every 2000
        steps as ``(niter - step) / (niter - 20000)``. Losses are printed
        every 500 steps and sample grids are written to ``self.opt.sample_dir``
        every 2000 steps. A ``model_<step>.ckpt`` checkpoint is saved every
        20000 steps and once more after the loop.
        """
        self.t_vars = tf.trainable_variables()
        self.dx_vars = [var for var in self.t_vars if 'Dx' in var.name]
        self.dy_vars = [var for var in self.t_vars if 'Dy' in var.name]
        self.gx_vars = [var for var in self.t_vars if 'Gx' in var.name]
        self.gy_vars = [var for var in self.t_vars if 'Gy' in var.name]
        self.e_vars = [var for var in self.t_vars if 'encode' in var.name]
        self.gr_vars = [var for var in self.t_vars if 'Gr' in var.name]
        # NOTE(review): these groups are not checked to cover every trainable
        # variable; 'Gr' (and anything outside the groups) is not in any
        # optimiser's var_list below -- confirm that is intended.

        self.saver = tf.train.Saver()
        self.p_saver = tf.train.Saver(self.gr_vars)
        opti_Dx = tf.train.AdamOptimizer(
            self.opt.lr_d * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.Dx_loss,
                                           var_list=self.dx_vars)
        opti_Dy = tf.train.AdamOptimizer(
            self.opt.lr_d * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.Dy_loss,
                                           var_list=self.dy_vars)
        opti_Gx = tf.train.AdamOptimizer(
            self.opt.lr_g * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.Gx_loss,
                                           var_list=self.gx_vars)
        opti_Gy = tf.train.AdamOptimizer(
            self.opt.lr_g * self.lr_decay,
            beta1=self.opt.beta1,
            beta2=self.opt.beta2).minimize(loss=self.Gy_loss,
                                           var_list=self.gy_vars +
                                           self.e_vars)

        # Build the logged loss tensors once. Adding or weighting losses
        # inside the loop would add new ops to the graph on every logging
        # step, so the graph (and its memory) would keep growing.
        loss_fetches = [
            self.Dx_loss + self.Dy_loss, self.Gx_loss, self.Gy_loss,
            self.opt.lam_r * self.recon_loss_x,
            self.opt.lam_r * self.recon_loss_y
        ]

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            start_step = 0

            variables_to_restore = slim.get_variables_to_restore(
                include=['vgg_16'])
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.opt.vgg_path)

            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoints are named 'model_<step>.ckpt'; resume at <step>.
                start_step = int(
                    ckpt.model_checkpoint_path.split('model_',
                                                     2)[1].split('.', 2)[0])
                self.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                try:
                    pretrain_ckpt = tf.train.get_checkpoint_state(
                        self.opt.pretrain_path)
                    # Raises AttributeError when no checkpoint was found;
                    # that is handled like any other restore failure.
                    self.p_saver.restore(sess,
                                         pretrain_ckpt.model_checkpoint_path)
                except Exception:
                    # Best effort: warn and keep the initialised weights.
                    # `Exception` (unlike a bare except) still lets
                    # KeyboardInterrupt/SystemExit through.
                    print(" Self-Guided Model path may not be correct")
            step = start_step
            lr_decay = 1

            print("Start read dataset")
            (train_images_x, train_eye_pos_x, train_images_y, train_eye_pos_y,
             test_images, test_eye_pos) = self.dataset.input()

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                # One fixed test batch is reused for every sample dump.
                real_test_batch, real_test_pos = sess.run(
                    [test_images, test_eye_pos])

                while step <= self.opt.niter:

                    if step > 20000 and step % 2000 == 0:
                        lr_decay = (self.opt.niter - step) / float(
                            self.opt.niter - 20000)

                    x_data, x_p_data = sess.run(
                        [train_images_x, train_eye_pos_x])
                    y_data, y_p_data = sess.run(
                        [train_images_y, train_eye_pos_y])
                    xm_data, x_left_p_data, x_right_p_data = self.get_Mask_and_pos(
                        x_p_data)
                    ym_data, y_left_p_data, y_right_p_data = self.get_Mask_and_pos(
                        y_p_data)

                    f_d = {
                        self.x: x_data,
                        self.xm: xm_data,
                        self.x_left_p: x_left_p_data,
                        self.x_right_p: x_right_p_data,
                        self.y: y_data,
                        self.ym: ym_data,
                        self.y_left_p: y_left_p_data,
                        self.y_right_p: y_right_p_data,
                        self.lr_decay: lr_decay
                    }

                    sess.run(opti_Dx, feed_dict=f_d)
                    sess.run(opti_Dy, feed_dict=f_d)
                    sess.run(opti_Gx, feed_dict=f_d)
                    sess.run(opti_Gy, feed_dict=f_d)

                    if step % 500 == 0:
                        output_loss = sess.run(loss_fetches, feed_dict=f_d)
                        print(
                            "step %d D_loss=%.4f, Gx_loss=%.4f, Gy_loss=%.4f, Recon_loss_x=%.4f, Recon_loss_y=%.4f, lr_decay=%.4f"
                            % (step, output_loss[0], output_loss[1],
                               output_loss[2], output_loss[3], output_loss[4],
                               lr_decay))

                    if np.mod(step, 2000) == 0:
                        o_list = sess.run([
                            self.xl_left, self.xl_right, self.xc, self.xo,
                            self.yl_left, self.yl_right, self.yc, self.yo,
                            self.y2x, self.y2x_
                        ],
                                          feed_dict=f_d)

                        test_masks, test_left_pos, test_right_pos = self.get_Mask_and_pos(
                            real_test_pos)
                        # The same test batch is fed as both x and y.
                        test_f_d = {
                            self.x: real_test_batch,
                            self.xm: test_masks,
                            self.x_left_p: test_left_pos,
                            self.x_right_p: test_right_pos,
                            self.y: real_test_batch,
                            self.ym: test_masks,
                            self.y_left_p: test_left_pos,
                            self.y_right_p: test_right_pos,
                            self.lr_decay: lr_decay
                        }

                        t_o_list = sess.run(
                            [self.xc, self.xo, self.yc, self.yo],
                            feed_dict=test_f_d)
                        train_trans = self.Transpose(
                            np.array([
                                x_data, o_list[2], o_list[3], o_list[6],
                                o_list[7], o_list[8], o_list[9]
                            ]))
                        l_trans = self.Transpose(
                            np.array(
                                [o_list[0], o_list[1], o_list[4], o_list[5]]))
                        test_trans = self.Transpose(
                            np.array([
                                real_test_batch, t_o_list[0], t_o_list[1],
                                t_o_list[2], t_o_list[3]
                            ]))

                        save_images(
                            l_trans,
                            '{}/{:02d}_lo_{}.jpg'.format(
                                self.opt.sample_dir, step,
                                self.opt.exper_name))
                        save_images(
                            train_trans,
                            '{}/{:02d}_tr_{}.jpg'.format(
                                self.opt.sample_dir, step,
                                self.opt.exper_name))
                        save_images(
                            test_trans,
                            '{}/{:02d}_te_{}.jpg'.format(
                                self.opt.sample_dir, step,
                                self.opt.exper_name))

                    if np.mod(step, 20000) == 0:
                        self.saver.save(
                            sess,
                            os.path.join(self.opt.checkpoints_dir,
                                         'model_{:06d}.ckpt'.format(step)))

                    step += 1

                save_path = self.saver.save(
                    sess,
                    os.path.join(self.opt.checkpoints_dir,
                                 'model_{:06d}.ckpt'.format(step)))
            finally:
                # Stop the input-queue threads even if training raised.
                coord.request_stop()
                coord.join(threads)

            print("Model saved in file: %s" % save_path)
コード例 #7
0
    def test(self):
        """Restore the newest checkpoint and save per-image test results.

        Restores the checkpoint recorded in ``self.model_dir`` (the process
        exits if none is found), then pulls ``1000 // self.batch_size``
        batches from the dataset's test queue. For every image of every
        batch, the fed input is written under ``self.testresult_dir + "4"``
        and ``self.new_recon_img`` under ``self.testresult_dir + "3"``, each
        as its own JPEG.
        """
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(self.model_dir)
            print('Load checkpoint')
            if ckpt and ckpt.model_checkpoint_path:
                # Call form prints the same text under Python 2 and 3.
                print(ckpt.model_checkpoint_path)
                self.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Load Succeed!')
            else:
                print('Do not exists any checkpoint,Load Failed!')
                exit()

            _, _, _, testbatch, testmask = self.dataset.input()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # Floor division: range() needs an int, and ``/`` yields a float
            # under Python 3.
            batch_num = 1000 // self.batch_size
            single_shape = (1, self.output_size, self.output_size,
                            self.channel)
            try:
                for j in range(batch_num):
                    real_test_batch, real_eye_pos = sess.run(
                        [testbatch, testmask])
                    batch_masks, batch_left_eye_pos, batch_right_eye_pos = self.get_Mask_and_pos(
                        real_eye_pos)
                    f_d = {
                        self.input: real_test_batch,
                        self.mask: batch_masks,
                        self.input_left_labels: batch_left_eye_pos,
                        self.input_right_labels: batch_right_eye_pos
                    }

                    # The feed is identical for every image of the batch, so
                    # run the graph once per batch instead of once per image.
                    # NOTE(review): if the graph contains random ops, earlier
                    # code drew a fresh sample per image; confirm this is fine.
                    output = sess.run([self.input, self.new_recon_img],
                                      feed_dict=f_d)

                    for i in range(self.batch_size):
                        save_images(
                            np.reshape(output[0][i], newshape=single_shape),
                            [1, 1], '{}/{:02d}_{:2d}_input.jpg'.format(
                                self.testresult_dir + "4", j, i))
                        save_images(
                            np.reshape(output[1][i], newshape=single_shape),
                            [1, 1], '{}/{:02d}_{:2d}_recon.jpg'.format(
                                self.testresult_dir + "3", j, i))
                        # NOTE(review): a third, half-resolution image
                        # ('..._recon_local_right.jpg') used to be saved from
                        # ``output[3]``. Only two tensors are fetched above,
                        # so that call always raised IndexError. To restore
                        # it, add the local right-eye tensor to the fetch list.
            finally:
                # Always shut the input-queue threads down, even on error.
                coord.request_stop()
                coord.join(threads)
コード例 #8
0
    def test(self):
        """Restore the newest checkpoint and write qualitative test grids.

        Restores the checkpoint recorded in ``self.opt.checkpoints_dir`` (the
        process exits if none is found). Then, ``self.opt.test_num`` times,
        one batch ``x`` is drawn from the train queue and one batch ``y``
        from the test queue, and four images are written to
        ``self.opt.test_sample_dir``:

        * ``{j:02d}.jpg`` -- x, xc, xo, yc, y, yo, y2x, y2x_
        * ``{j:02d}_local.jpg`` -- the ten local (eye-region) outputs
        * ``{j:02d}inter1.jpg`` -- y, an all-ones image, then ``_y2x_inter``
          for alpha = 0.0, 0.1, ..., 1.0
        * ``{j:02d}inter1_1.jpg`` -- same layout, alpha = 0.0, 0.1, ..., 1.4
        """
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            self.saver = tf.train.Saver()

            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            print('Load checkpoint')
            if ckpt and ckpt.model_checkpoint_path:
                self.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Load Succeed!')
            else:
                print('Do not exists any checkpoint,Load Failed!')
                exit()

            trainbatch, trainmask, _, _, testbatch, testmask = self.dataset.input(
            )
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            batch_size = self.opt.batch_size
            img_size = self.opt.img_size
            # All-ones image placed after ``y`` at the head of each alpha grid.
            blank = np.ones(shape=[batch_size, img_size, img_size, 3])

            def alpha_sweep(base_feed, steps):
                """Run ``_y2x_inter`` once per step with alpha = step / 10."""
                results = []
                for step in steps:
                    feed = dict(base_feed)
                    # np.full works for any batch size. Reshaping a
                    # one-element list only worked when batch_size == 1.
                    feed[self.alpha] = np.full((batch_size, 1), step / 10.0)
                    results.append(sess.run(self._y2x_inter, feed_dict=feed))
                return results

            try:
                for j in range(self.opt.test_num):
                    x_img, x_img_pos, y_img, y_img_pos = sess.run(
                        [trainbatch, trainmask, testbatch, testmask])
                    x_m, x_left_pos, x_right_pos = self.get_Mask_and_pos(
                        x_img_pos)
                    y_m, y_left_pos, y_right_pos = self.get_Mask_and_pos(
                        y_img_pos)

                    f_d = {
                        self.x: x_img,
                        self.xm: x_m,
                        self.x_left_p: x_left_pos,
                        self.x_right_p: x_right_pos,
                        self.y: y_img,
                        self.ym: y_m,
                        self.y_left_p: y_left_pos,
                        self.y_right_p: y_right_pos
                    }

                    output = sess.run([
                        self.x, self.xc, self.xo, self.yc, self.y, self.yo,
                        self.y2x, self.y2x_
                    ],
                                      feed_dict=f_d)
                    output_concat = self.Transpose(np.array(output))

                    local_output = sess.run([
                        self.xl_left, self.xl_right, self.yl_left,
                        self.yl_right, self._xl_left, self._xl_right,
                        self._yl_left, self._yl_right, self.y2x_left,
                        self.y2x_right
                    ],
                                            feed_dict=f_d)
                    local_output_concat = self.Transpose(
                        np.array(local_output))

                    inter_results = [y_img, blank] + alpha_sweep(
                        f_d, range(0, 11))
                    inter_results1 = [y_img, blank] + alpha_sweep(
                        f_d, range(0, 15))
                    # NOTE(review): earlier code also swept alpha over
                    # 1.1..2.1 and -1.0..-0.1 but never saved those results.
                    # Those 21 wasted forward passes per iteration are gone.

                    save_images(
                        output_concat,
                        '{}/{:02d}.jpg'.format(self.opt.test_sample_dir, j))
                    save_images(
                        local_output_concat,
                        '{}/{:02d}_local.jpg'.format(self.opt.test_sample_dir,
                                                     j))
                    save_images(
                        self.Transpose(np.array(inter_results)),
                        '{}/{:02d}inter1.jpg'.format(self.opt.test_sample_dir,
                                                     j))
                    save_images(
                        self.Transpose(np.array(inter_results1)),
                        '{}/{:02d}inter1_1.jpg'.format(
                            self.opt.test_sample_dir, j))
            finally:
                # Always shut the input-queue threads down, even on error.
                coord.request_stop()
                coord.join(threads)
コード例 #9
0
    def train(self):
        """Train the generator (encoder + decoder) and the discriminator.

        Both networks use RMSProp. Each iteration runs one D step and then
        one G step, and writes the merged summaries to ``self.opt.log_dir``.
        If ``self.opt.checkpoints_dir`` holds a checkpoint, training resumes
        from it; the start step is parsed from its ``model_XXXXXX.ckpt``
        name. Training continues while ``step <= self.opt.niter``. Once
        ``step`` exceeds ``self.opt.niter_decay``, the learning-rate factor
        decays linearly toward 0 and is refreshed every 2000 steps.
        """
        log_vars = [('D_loss', self.D_loss), ('G_loss', self.G_loss)]

        # ``t_vars`` rather than ``vars`` so the builtin is not shadowed.
        t_vars = tf.trainable_variables()
        g_vars = getTrainVariable(t_vars, scope='encoder') + getTrainVariable(
            t_vars, scope='decoder')
        d_vars = getTrainVariable(t_vars, scope='discriminator')

        # Every trainable variable must be owned by one of the two optimizers.
        assert len(t_vars) == len(g_vars) + len(d_vars)

        saver = tf.train.Saver()
        for k, v in log_vars:
            tf.summary.scalar(k, v)

        opti_G = tf.train.RMSPropOptimizer(
            self.opt.lr_g * self.lr_decay).minimize(loss=self.G_loss,
                                                    var_list=g_vars)
        # NOTE(review): D is also trained with ``lr_g``. Confirm whether a
        # separate discriminator learning rate was intended.
        opti_D = tf.train.RMSPropOptimizer(
            self.opt.lr_g * self.lr_decay).minimize(loss=self.D_loss,
                                                    var_list=d_vars)
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.opt.log_dir,
                                                   sess.graph)

            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoints are named 'model_{step:06d}.ckpt'.
                start_step = int(
                    ckpt.model_checkpoint_path.split('model_',
                                                     2)[1].split('.', 2)[0])
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Load Successfully!', ckpt.model_checkpoint_path)
            else:
                start_step = 0

            step = start_step
            lr_decay = 1

            print("Start reading dataset")
            try:
                while step <= self.opt.niter:

                    if step > self.opt.niter_decay and step % 2000 == 0:
                        # Previously read ``self.opt.iter_decay``, which does
                        # not match the ``niter_decay`` option tested above.
                        lr_decay = (self.opt.niter - step) / float(
                            self.opt.niter - self.opt.niter_decay)

                    source_image_x_data, target_image_y1_data, cls_x, cls_y = self.data_ob.getNextBatch(
                    )
                    source_image_x = self.data_ob.getShapeForData(
                        source_image_x_data)
                    target_image_y1 = self.data_ob.getShapeForData(
                        target_image_y1_data)

                    f_d = {
                        self.x: source_image_x,
                        self.y_1: target_image_y1,
                        self.cls_x: cls_x,
                        self.cls_y: cls_y,
                        self.lr_decay: lr_decay
                    }

                    sess.run(opti_D, feed_dict=f_d)
                    sess.run(opti_G, feed_dict=f_d)

                    summary_str = sess.run(summary_op, feed_dict=f_d)
                    summary_writer.add_summary(summary_str, step)

                    if step % self.opt.display_freq == 0:

                        output_loss = sess.run([
                            self.D_loss, self.D_gan_loss, self.G_loss,
                            self.G_gan_loss, self.content_recon_loss,
                            self.feature_matching, self.l2_loss_d,
                            self.l2_loss_g
                        ],
                                               feed_dict=f_d)
                        print(
                            "step %d, D_loss=%.4f, D_gan_loss=%.4f"
                            " G_loss=%.4f, G_gan_loss=%.4f, content_recon=%.4f, feautre_loss=%.4f, l2_loss=%.4f, lr_decay=%.4f"
                            %
                            (step, output_loss[0], output_loss[1],
                             output_loss[2], output_loss[3], output_loss[4],
                             output_loss[5], output_loss[6] + output_loss[7],
                             lr_decay))

                    if np.mod(step, self.opt.save_latest_freq) == 0:
                        # Reuse the full training feed. The sample fetches may
                        # depend on placeholders (e.g. cls_x / cls_y) that a
                        # reduced {x, y_1} feed would leave unfed.
                        train_output_img = sess.run(
                            [self.x, self.y_1, self.tilde_x, self.x_recon],
                            feed_dict=f_d)

                        output_img = np.concatenate(train_output_img[0:4],
                                                    axis=0)

                        # Floor division keeps the grid size an int in Py3.
                        save_images(
                            output_img, [
                                output_img.shape[0] // self.opt.batchSize,
                                self.opt.batchSize
                            ], '{}/{:02d}_output_img.jpg'.format(
                                self.opt.sample_dir, step))

                    if np.mod(step, self.opt.save_model_freq) == 0 and step != 0:
                        saver.save(
                            sess,
                            os.path.join(self.opt.checkpoints_dir,
                                         'model_{:06d}.ckpt'.format(step)))
                    step += 1

                save_path = saver.save(
                    sess,
                    os.path.join(self.opt.checkpoints_dir,
                                 'model_{:06d}.ckpt'.format(step)))
            finally:
                # Flush and close the event file even if training fails.
                summary_writer.close()

            print("Model saved in file: %s" % save_path)
コード例 #10
0
    def train(self):
        """Train the generator (encoder + decoder) and the discriminator.

        Both networks use RMSProp with separate learning rates
        (``self.g_learning_rate`` and ``self.d_learning_rate``). Each
        iteration runs one D step and then one G step. If
        ``self.write_model_path`` holds a checkpoint, training resumes from
        it. Training runs while ``step <= self.max_iters``.

        Periodic work:
        * losses are printed every 50 steps;
        * a sample grid goes to ``self.sample_path`` every 1000 steps;
        * a checkpoint is written every 20000 steps.

        After step 20000, the learning-rate factor decays linearly toward 0
        and is refreshed every 2000 steps.
        """
        # Schedule constants (unchanged values, named for readability).
        decay_start = 20000
        decay_every = 2000
        log_every = 50
        sample_every = 1000
        ckpt_every = 20000

        opti_G = tf.train.RMSPropOptimizer(
            self.g_learning_rate * self.lr_decay).minimize(
                loss=self.G_loss,
                var_list=self.encoder_vars + self.decoder_vars)
        opti_D = tf.train.RMSPropOptimizer(self.d_learning_rate *
                                           self.lr_decay).minimize(
                                               loss=self.D_loss,
                                               var_list=self.d_vars)
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

            ckpt = tf.train.get_checkpoint_state(self.write_model_path)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoints are named 'model_{step:06d}.ckpt'.
                start_step = int(
                    ckpt.model_checkpoint_path.split('model_',
                                                     2)[1].split('.', 2)[0])
                self.saver.restore(sess, ckpt.model_checkpoint_path)
                print(ckpt.model_checkpoint_path)
                print('Load Successfully!')
            else:
                start_step = 0

            step = start_step
            lr_decay = 1

            print("Start read dataset")
            try:
                while step <= self.max_iters:

                    if step > decay_start and step % decay_every == 0:
                        lr_decay = (self.max_iters - step) / float(
                            self.max_iters - decay_start)

                    # The third item (a second target image) is unused here.
                    source_image_x_data, target_image_y1_data, _, cls_x, cls_y = self.data_ob.getNextBatch(
                    )
                    source_image_x = self.data_ob.getShapeForData(
                        source_image_x_data)
                    target_image_y1 = self.data_ob.getShapeForData(
                        target_image_y1_data)

                    f_d = {
                        self.x: source_image_x,
                        self.y_1: target_image_y1,
                        self.cls_x: cls_x,
                        self.cls_y: cls_y,
                        self.lr_decay: lr_decay
                    }

                    sess.run(opti_D, feed_dict=f_d)
                    sess.run(opti_G, feed_dict=f_d)

                    summary_str = sess.run(summary_op, feed_dict=f_d)
                    summary_writer.add_summary(summary_str, step)

                    if step % log_every == 0:

                        output_loss = sess.run([
                            self.D_loss, self.D_gan_loss, self.G_loss,
                            self.G_gan_loss, self.content_recon_loss,
                            self.feature_matching
                        ],
                                               feed_dict=f_d)
                        print(
                            "step %d, D_loss=%.4f, D_gan_loss=%.4f"
                            " G_loss=%.4f, G_gan_loss=%.4f, content_recon=%.4f, feautre_loss=%.4f"
                            ", lr_decay=%.4f" %
                            (step, output_loss[0], output_loss[1],
                             output_loss[2], output_loss[3], output_loss[4],
                             output_loss[5], lr_decay))

                    if np.mod(step, sample_every) == 0:
                        # Reuse the full training feed. The sample fetches may
                        # depend on placeholders (e.g. cls_x / cls_y) that a
                        # reduced {x, y_1} feed would leave unfed.
                        train_output_img = sess.run(
                            [self.x, self.y_1, self.tilde_x, self.x_recon],
                            feed_dict=f_d)

                        output_img = np.concatenate(train_output_img[0:4],
                                                    axis=0)

                        # Floor division keeps the grid size an int in Py3.
                        save_images(
                            output_img, [
                                output_img.shape[0] // self.batch_size,
                                self.batch_size
                            ], '{}/{:02d}_output_img.jpg'.format(
                                self.sample_path, step))

                    if np.mod(step, ckpt_every) == 0 and step != 0:
                        self.saver.save(
                            sess,
                            os.path.join(self.write_model_path,
                                         'model_{:06d}.ckpt'.format(step)))
                    step += 1

                save_path = self.saver.save(
                    sess,
                    os.path.join(self.write_model_path,
                                 'model_{:06d}.ckpt'.format(step)))
            finally:
                # Flush and close the event file even if training fails.
                summary_writer.close()

            # Call form works under both Python 2 and Python 3.
            print("Model saved in file: %s" % save_path)
コード例 #11
0
    def train(self):
        """Train the gaze-correction GAN with Adam, then export the graph.

        Trainable variables are split by name into three groups:
        discriminator (``'D'``), generator (``'G'``) and encoder
        (``'encode'``). If ``self.opt.checkpoints_dir`` holds a checkpoint,
        training resumes from it. Each iteration runs one D step and then one
        G step, while ``step <= self.opt.niter``.

        Periodic work:
        * losses are printed every 500 steps;
        * train/test sample grids go to ``self.opt.sample_dir`` every 2000
          steps;
        * a checkpoint is written every 20000 steps.

        After training, the graph is written for TensorBoard (``__tb``) and
        as a text GraphDef (``log3_25_1/saved_model.pbtxt``).
        """
        # Step after which the learning-rate factor decays linearly to 0.
        decay_start = 20000

        self.t_vars = tf.trainable_variables()
        self.d_vars = [var for var in self.t_vars if 'D' in var.name]
        self.g_vars = [var for var in self.t_vars if 'G' in var.name]
        self.e_vars = [var for var in self.t_vars if 'encode' in var.name]
        # Sanity check (by count) that the name-based split covers every
        # trainable variable.
        assert len(self.t_vars) == len(self.d_vars + self.g_vars + self.e_vars)

        self.saver = tf.train.Saver()
        # Encoder-only saver; presumably for restoring a pretrained encoder.
        self.p_saver = tf.train.Saver(self.e_vars)

        opti_D = tf.train.AdamOptimizer(self.opt.lr_d * self.lr_decay, beta1=self.opt.beta1, beta2=self.opt.beta2).\
                                        minimize(loss=self.D_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(self.opt.lr_g * self.lr_decay, beta1=self.opt.beta1, beta2=self.opt.beta2).\
                                        minimize(loss=self.G_loss, var_list=self.g_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            start_step = 0
            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoints are named 'model_{step:06d}.ckpt'.
                start_step = int(
                    ckpt.model_checkpoint_path.split('model_',
                                                     2)[1].split('.', 2)[0])
                self.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # No checkpoint: training starts from scratch. A pretrained
                # encoder could be restored here through ``self.p_saver``
                # (that restore is currently disabled).
                print("")

            step = start_step
            lr_decay = 1

            print("Start read dataset")

            image_path, train_images, train_eye_pos, test_images, test_eye_pos = self.dataset.input(
            )
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print("Start entering the looping")
                # One fixed test batch is reused for every sample dump.
                real_test_batch, real_test_pos = sess.run(
                    [test_images, test_eye_pos])

                while step <= self.opt.niter:

                    if step > decay_start and step % 2000 == 0:
                        lr_decay = (self.opt.niter - step) / float(
                            self.opt.niter - decay_start)

                    real_batch_image_path, x_data, x_p_data = sess.run(
                        [image_path, train_images, train_eye_pos])
                    xm_data, x_left_p_data, x_right_p_data = self.get_Mask_and_pos(
                        x_p_data)

                    f_d = {
                        self.x: x_data,
                        self.xm: xm_data,
                        self.x_left_p: x_left_p_data,
                        self.x_right_p: x_right_p_data,
                        self.lr_decay: lr_decay
                    }

                    sess.run(opti_D, feed_dict=f_d)
                    sess.run(opti_G, feed_dict=f_d)

                    if step % 500 == 0:

                        if self.opt.is_ss:
                            output_loss = sess.run([
                                self.D_loss, self.G_loss, self.opt.lam_r *
                                self.recon_loss,
                                self.opt.lam_p * self.percep_loss,
                                self.r_cls_loss, self.f_cls_loss
                            ],
                                                   feed_dict=f_d)
                            print(
                                "step %d D_loss=%.8f, G_loss=%.4f, Recon_loss=%.4f, Percep_loss=%.4f, "
                                "Real_class_loss=%.4f, Fake_class_loss=%.4f, lr_decay=%.4f"
                                % (step, output_loss[0], output_loss[1],
                                   output_loss[2], output_loss[3],
                                   output_loss[4], output_loss[5], lr_decay))
                        else:
                            output_loss = sess.run([
                                self.D_loss, self.G_loss, self.opt.lam_r *
                                self.recon_loss,
                                self.opt.lam_p * self.percep_loss
                            ],
                                                   feed_dict=f_d)
                            print(
                                "step %d D_loss=%.8f, G_loss=%.4f, Recon_loss=%.4f, Percep_loss=%.4f, lr_decay=%.4f"
                                % (step, output_loss[0], output_loss[1],
                                   output_loss[2], output_loss[3], lr_decay))

                    if np.mod(step, 2000) == 0:

                        train_output_img = sess.run([
                            self.xl_left, self.xl_right, self.xc, self.yo,
                            self.y, self.yl_left, self.yl_right
                        ],
                                                    feed_dict=f_d)

                        batch_masks, batch_left_eye_pos, batch_right_eye_pos = self.get_Mask_and_pos(
                            real_test_pos)
                        test_f_d = {
                            self.x: real_test_batch,
                            self.xm: batch_masks,
                            self.x_left_p: batch_left_eye_pos,
                            self.x_right_p: batch_right_eye_pos,
                            self.lr_decay: lr_decay
                        }

                        test_output_img = sess.run([self.xc, self.yo, self.y],
                                                   feed_dict=test_f_d)
                        output_concat = self.Transpose(
                            np.array([
                                x_data, train_output_img[2],
                                train_output_img[3], train_output_img[4]
                            ]))
                        local_output_concat = self.Transpose(
                            np.array([
                                train_output_img[0], train_output_img[1],
                                train_output_img[5], train_output_img[6]
                            ]))
                        test_output_concat = self.Transpose(
                            np.array([
                                real_test_batch, test_output_img[0],
                                test_output_img[2], test_output_img[1]
                            ]))
                        save_images(
                            local_output_concat,
                            '{}/{:02d}_local_output.jpg'.format(
                                self.opt.sample_dir, step))
                        save_images(
                            output_concat,
                            '{}/{:02d}_output.jpg'.format(self.opt.sample_dir,
                                                          step))
                        save_images(
                            test_output_concat,
                            '{}/{:02d}_test_output.jpg'.format(
                                self.opt.sample_dir, step))

                    if np.mod(step, 20000) == 0:
                        self.saver.save(
                            sess,
                            os.path.join(self.opt.checkpoints_dir,
                                         'model_{:06d}.ckpt'.format(step)))

                    step += 1

                save_path = self.saver.save(
                    sess,
                    os.path.join(self.opt.checkpoints_dir,
                                 'model_{:06d}.ckpt'.format(step)))
            finally:
                # Always shut the input-queue threads down, even on error.
                coord.request_stop()
                coord.join(threads)

            # Write the graph for TensorBoard. Closing the writer flushes the
            # event file to disk.
            tb_writer = tf.summary.FileWriter("__tb", sess.graph)
            tb_writer.close()
            print("\n Graph File written\n")

            # Write the GraphDef as text. This holds the graph structure only;
            # variable values are NOT included. Producing a frozen binary
            # '.pb' would need tensorflow.python.tools.freeze_graph run on
            # this file plus ``save_path``. That step is currently disabled,
            # so no '.pb' file is written.
            print("\n About to write the graph definition\n")
            filename = "saved_model"
            directory = "log3_25_1"
            pbtxt_filename = filename + '.pbtxt'
            pbtxt_filepath = os.path.join(directory, pbtxt_filename)
            tf.train.write_graph(graph_or_graph_def=sess.graph_def,
                                 logdir=directory,
                                 name=pbtxt_filename,
                                 as_text=True)
            print("\n Graph definition written (not frozen)\n")

            print("Model saved in file: %s \n" % save_path)
            print("Graph definition saved in File: %s" % pbtxt_filepath)
コード例 #12
0
    def test(self,
             freeze_model,
             num_custom_images,
             flag_save_images=True,
             custom_dataset=True,
             test_sample_dir=None):
        """Run inference over a test set and optionally export the graph.

        Restores the latest checkpoint found in ``self.opt.checkpoints_dir``,
        then feeds each test batch through the network (fetching
        ``self.x`` and ``self.y``). If requested, it saves the
        concatenation of input and output for every batch as a JPEG.

        Args:
            freeze_model: If true, after inference also write:
                - a TensorBoard graph to ``__tb_test``;
                - a text GraphDef to ``log3_25_1/saved_model_test.pbtxt``;
                - a frozen binary GraphDef (output node ``"add"``) to
                  ``log3_25_1/inference_graph_3_batch1.pb``.
            num_custom_images: Number of batches to run when
                ``custom_dataset`` is true. This is presumably one image per
                batch; confirm against ``custom_test_input``.
            flag_save_images: If true, save each input+output image as
                ``<test_sample_dir>/<j:02d>.jpg``, where ``j`` is the batch
                index.
            custom_dataset: If true, read batches from
                ``self.dataset.custom_test_input()``. Otherwise read the test
                split returned by ``self.dataset.input()`` (3451 images).
            test_sample_dir: Directory for the saved images. ``None`` keeps
                the path this method has always written to.

        Raises:
            SystemExit: if no checkpoint is found.
        """
        import sys

        if test_sample_dir is None:
            # Historical hard-coded location. Pass test_sample_dir (e.g.
            # self.opt.test_sample_dir) to write somewhere else.
            test_sample_dir = "/disk/projectEyes/GazeCorrection/log3_25_1/test_sample_dir"

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.saver = tf.train.Saver()

        with tf.Session(config=config) as sess:
            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir)
            print('Load checkpoint')
            if ckpt and ckpt.model_checkpoint_path:
                self.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Load Succeed!')
            else:
                print('Do not exists any checkpoint,Load Failed!')
                sys.exit()

            # The input ops must exist before the queue runners are started.
            if custom_dataset:
                batch_num = int(num_custom_images)
                testbatch, testmask = self.dataset.custom_test_input()
            else:
                # The standard test split holds 3451 images.
                batch_num = int(3451 // self.opt.batch_size)
                _, _, _, testbatch, testmask = self.dataset.input()

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                start_time = time.time()

                for j in range(batch_num):
                    real_test_batch, real_eye_pos = sess.run([testbatch, testmask])
                    batch_masks, batch_left_eye_pos, batch_right_eye_pos = self.get_Mask_and_pos(
                        real_eye_pos)
                    f_d = {
                        self.x: real_test_batch,
                        self.xm: batch_masks,
                        self.x_left_p: batch_left_eye_pos,
                        self.x_right_p: batch_right_eye_pos
                    }

                    output = sess.run([self.x, self.y], feed_dict=f_d)
                    if flag_save_images:
                        # Save the concatenation of input (x) and output (y).
                        output_concat = self.Transpose(
                            np.array([output[0], output[1]]))
                        save_images(
                            output_concat,
                            '{}/{:02d}.jpg'.format(test_sample_dir, j))

                print(
                    "\n \n INNER Time elapsed in GazeGan inference using TF of %d batches = "
                    % batch_num,
                    time.time() - start_time)

                if freeze_model:
                    export_dir = "log3_25_1"

                    # TensorBoard visualization of the test graph. Close the
                    # writer so the event file is flushed now, not at GC time.
                    writer = tf.summary.FileWriter("__tb_test", sess.graph)
                    writer.close()
                    print("\n Graph File written\n")

                    print("\n About to Freeze the graph\n")
                    # Text GraphDef only; variable values are not included.
                    tf.train.write_graph(graph_or_graph_def=sess.graph_def,
                                         logdir=export_dir,
                                         name="saved_model_test.pbtxt",
                                         as_text=True)

                    # Bake current variable values into constants. "add" is
                    # the output node name used for the export. It is
                    # presumably the op producing self.y; confirm against the
                    # graph.
                    frozen = tf.graph_util.convert_variables_to_constants(
                        sess, sess.graph_def, ["add"])
                    tf.train.write_graph(frozen,
                                         export_dir,
                                         'inference_graph_3_batch1.pb',
                                         as_text=False)
            finally:
                # Always stop the queue-runner threads, even if inference fails.
                coord.request_stop()
                coord.join(threads)