Example #1
0
    def test(self, model_step=100000, batch_num=200):
        """Restore a checkpoint and write inpainting result grids for test batches.

        Args:
            model_step: Step number of the checkpoint
                ``<self.write_model_path>/model_<model_step:06d>.ckpt`` to
                restore. Defaults to 100000, the value previously hard-coded.
            batch_num: Number of test batches to evaluate (default 200).

        For each batch ``j`` the real batch, the mask batch, ``self.incomplete_img``
        and ``self.x_tilde`` are concatenated along axis 0 and written as one
        image grid to ``<self.test_sample_path>/<j:02d>_test_output.jpg``.
        """
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            ckpt_path = os.path.join(self.write_model_path, 'model_{:06d}.ckpt'.format(model_step))
            self.saver.restore(sess, ckpt_path)
            # Only the test tensors are used; the training tensors are ignored.
            _, _, _, _, testbatch, testmask = self.data_ob.input()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # try/finally so the queue-runner threads are always stopped and
            # joined, even if a batch fails.
            try:
                for j in range(batch_num):
                    real_test_batch, real_test_mask = sess.run([testbatch, testmask])
                    f_d = {self.input: real_test_batch, self.mask: real_test_mask}
                    test_incomplete_img, test_x_tilde = sess.run(
                        [self.incomplete_img, self.x_tilde], feed_dict=f_d)
                    test_output_concat = np.concatenate(
                        [real_test_batch, real_test_mask, test_incomplete_img, test_x_tilde], axis=0)
                    # Floor division keeps the grid row count an int on Python 3 as well.
                    save_images(test_output_concat,
                                [test_output_concat.shape[0] // self.batch_size, self.batch_size],
                                '{}/{:02d}_test_output.jpg'.format(self.test_sample_path, j))
            finally:
                coord.request_stop()
                coord.join(threads)
Example #2
0
    def train(self):
        """Run one training stage of the inpainting GAN.

        The stage covers global steps ``(self.base_step - 4) * self.max_iters``
        through that value plus ``self.max_iters``, inclusive. Unless
        ``self.pg`` is 3 or 8, weights are first restored from
        ``self.model_read``. When ``self.is_trans`` is set, only the
        ``read_saver`` and ``rgb_saver`` subsets are restored; otherwise
        everything is restored through ``self.saver``.

        Each iteration performs these steps:

        * run ``self.n_critic`` discriminator updates on fresh batches;
        * run one generator update, reusing the feed dict of the last critic step;
        * write summaries;
        * assign ``self.alpha_trans = this_step / self.max_iters``.

        Losses are printed every 100 steps. Train/test sample grids are written
        to ``self.sample_path`` every 1000 steps. ``self.model_write`` is saved
        every 10000 steps (excluding step 0) and once more at the end. The
        default graph is reset afterwards.

        Raises:
            ValueError: if ``self.n_critic`` < 1. The generator update needs the
                feed dict built by at least one critic step.
        """
        if self.n_critic < 1:
            raise ValueError("n_critic must be >= 1, got %r" % (self.n_critic,))

        # alpha_trans is driven from Python through this placeholder.
        step_pl = tf.placeholder(tf.float32, shape=None)
        alpha_trans_assign = self.alpha_trans.assign(step_pl / self.max_iters)
        opti_D = tf.train.AdamOptimizer(self.d_learning_rate * self.lr_decay,
                                        beta1=self.beta1, beta2=self.beta2).minimize(loss=self.D_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(self.g_learning_rate * self.lr_decay,
                                        beta1=self.beta1, beta2=self.beta2).minimize(loss=self.G_loss, var_list=self.ed_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            # Global step window of this stage. The reason for the -4 offset
            # depends on how the caller numbers base_step (not visible here).
            step = (self.base_step - 4) * self.max_iters
            max_iters = step + self.max_iters

            print(step)
            print(max_iters)
            print("Start read dataset")
            batch1, mask1, batch2, mask2, testbatch, testmask = self.data_ob.input(image_size=self.output_size)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # try/finally so the summary writer is closed and the queue-runner
            # threads are stopped and joined even if training fails.
            try:
                print("Start entering the looping")
                print("model_read %s" % (self.model_read,))

                print(self.pg)
                if self.pg != 3 and self.pg != 8:
                    if self.is_trans:
                        print("read variables")
                        self.read_saver.restore(sess, self.model_read)
                        self.rgb_saver.restore(sess, self.model_read)
                    else:
                        print("read all variables")
                        self.saver.restore(sess, self.model_read)

                # A single fixed test batch is reused for every sample dump.
                real_test_batch, real_test_mask = sess.run([testbatch, testmask])
                real_sample_alpha = 1
                this_step = 0
                while step <= max_iters:

                    # Linear learning-rate decay over the whole schedule (all_iters).
                    lr_decay = (self.all_iters - step) / float(self.all_iters)
                    for _ in range(self.n_critic):
                        real_batch_1, real_mask1 = sess.run([batch1, mask1])
                        f_d = {self.input: real_batch_1, self.mask: real_mask1, self.lr_decay: lr_decay}
                        # optimize D
                        sess.run(opti_D, feed_dict=f_d)

                    # optimize G on the batch used by the last critic step
                    sess.run(opti_G, feed_dict=f_d)
                    summary_str = sess.run(summary_op, feed_dict=f_d)
                    summary_writer.add_summary(summary_str, this_step)

                    sess.run(alpha_trans_assign, feed_dict={step_pl: this_step})

                    if step % 100 == 0:

                        D_loss, D_loss_original, G_loss, Recon_loss, Alpha_tra = sess.run(
                            [self.D_loss, self.d_gan_loss_original, self.G_loss, self.recon_loss, self.alpha_trans],
                            feed_dict=f_d)
                        print("PG=%d step %d D_loss=%.4f, D_original=%.4f, G_loss=%.4f, Recon_loss=%.4f, lr_decay=%.4f, Alpha_tra=%.4f,"
                              "Real_sample_alpha=%.4f" % (self.pg,
                            step, D_loss, D_loss_original, G_loss, Recon_loss, lr_decay, Alpha_tra, real_sample_alpha))

                    if step % 1000 == 0:

                        incomplete_img1, x_tilde1 = sess.run([self.incomplete_img, self.x_tilde], feed_dict=f_d)
                        x_tilde1 = np.clip(x_tilde1, -1, 1)
                        # for test
                        test_f_d = {self.input: real_test_batch, self.mask: real_test_mask}
                        test_incomplete_img, test_x_tilde = sess.run(
                            [self.incomplete_img, self.x_tilde], feed_dict=test_f_d)

                        test_x_tilde = np.clip(test_x_tilde, -1, 1)
                        output_concat = np.concatenate([real_batch_1, real_mask1, incomplete_img1, x_tilde1], axis=0)
                        test_output_concat = np.concatenate(
                            [real_test_batch, real_test_mask, test_incomplete_img, test_x_tilde], axis=0)
                        # Floor division keeps the grid row count an int on Python 3 as well.
                        save_images(output_concat, [output_concat.shape[0] // self.batch_size, self.batch_size],
                                    '{}/{:02d}_output3.jpg'.format(self.sample_path, step))
                        save_images(test_output_concat, [test_output_concat.shape[0] // self.batch_size, self.batch_size],
                                    '{}/{:02d}_test_output3.jpg'.format(self.sample_path, step))

                    if step % 10000 == 0 and step != 0:
                        self.saver.save(sess, self.model_write)

                    step += 1
                    this_step += 1

                print("model_write %s" % (self.model_write,))
                save_path = self.saver.save(sess, self.model_write)
            finally:
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)

            print("Model saved in file: %s" % save_path)

        tf.reset_default_graph()
    def train(self):
        """Train the generator/discriminator pair using eye-position masks.

        Builds Adam optimizers for ``self.D_loss`` (over ``self.d_vars``) and
        ``self.G_loss`` (over ``self.ed_vars``). It then tries to warm-start
        from ``<self.read_model_path>/model_100000.ckpt`` via
        ``self.load_saver``. After that it runs one D update and one G update
        per step for ``step`` = 0 .. ``self.max_iters``. Masks and left/right
        eye labels for each batch come from ``self.get_Mask_and_pos``.

        Schedule:
          * ``lr_decay`` is 1 for the first 20000 steps. After that, every
            2000 steps it is set to ``(max_iters - step) / (max_iters - 20000)``,
            a staircase linear decay.
          * Losses are printed every 500 steps.
          * Local, train and test sample grids are written to
            ``self.sample_path`` every 2000 steps.
          * ``<self.write_model_path>/model_<step:06d>.ckpt`` is saved every
            20000 steps (excluding step 0) and once more after the loop.
        """
        opti_D = tf.train.AdamOptimizer(self.d_learning_rate * self.lr_decay,
                                        beta1=self.beta1, beta2=self.beta2).minimize(loss=self.D_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(self.g_learning_rate * self.lr_decay,
                                        beta1=self.beta1, beta2=self.beta2).minimize(loss=self.G_loss, var_list=self.ed_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            step = 0
            lr_decay = 1

            # The warm start is best-effort. If restoring fails, training
            # continues from the fresh initialisation, but the cause is
            # reported instead of being silently discarded.
            try:
                self.load_saver.restore(sess, os.path.join(self.read_model_path, 'model_{:06d}.ckpt'.format(100000)))
            except Exception as e:
                print("Model path may not be correct: %s" % (e,))

            print("Start read dataset")
            batch_image_path, batch_image, eye_pos, testbatch_image, test_eye_pos = self.data_ob.input()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # try/finally so the summary writer is closed and the queue-runner
            # threads are stopped and joined even if training fails.
            try:
                print("Start entering the looping")
                # A single fixed test batch is reused for every sample dump.
                real_test_batch, real_test_pos = sess.run([testbatch_image, test_eye_pos])

                while step <= self.max_iters:

                    if step > 20000 and step % 2000 == 0:
                        lr_decay = (self.max_iters - step) / float(self.max_iters - 20000)

                    # The image path is fetched together with the batch but not used.
                    _, real_batch_image, real_eye_pos = sess.run([batch_image_path, batch_image, eye_pos])
                    batch_masks, batch_left_eye_pos, batch_right_eye_pos = self.get_Mask_and_pos(real_eye_pos)

                    f_d = {self.input: real_batch_image, self.input_masks: batch_masks,
                           self.input_left_labels: batch_left_eye_pos, self.input_right_labels: batch_right_eye_pos,
                           self.lr_decay: lr_decay}

                    # optimize D
                    sess.run(opti_D, feed_dict=f_d)
                    # optimize G
                    sess.run(opti_G, feed_dict=f_d)
                    summary_str = sess.run(summary_op, feed_dict=f_d)
                    summary_writer.add_summary(summary_str, step)

                    if step % 500 == 0:

                        d_loss, g_loss, recon_loss, fp_loss = sess.run(
                            [self.D_loss, self.G_loss, self.lam_recon * self.recon_loss, self.lam_fp * self.fp_loss],
                            feed_dict=f_d)
                        print("step %d D_loss1=%.8f, G_loss=%.4f, Recon_loss=%.4f, Fp_loss=%.4f, lr_decay=%.4f" % (
                            step, d_loss, g_loss, recon_loss, fp_loss, lr_decay))

                    if step % 2000 == 0:
                        (local_input_left, local_input_right, incomplete_img, x_tilde,
                         new_x_tilde, local_x_tilde_left, local_x_tilde_right) = sess.run(
                            [self.local_input_left, self.local_input_right, self.incomplete_img, self.x_tilde,
                             self.new_x_tilde, self.local_x_tilde_left, self.local_x_tilde_right], feed_dict=f_d)

                        # for test
                        test_masks, test_left_eye_pos, test_right_eye_pos = self.get_Mask_and_pos(real_test_pos)
                        test_f_d = {self.input: real_test_batch, self.input_masks: test_masks,
                                    self.input_left_labels: test_left_eye_pos,
                                    self.input_right_labels: test_right_eye_pos, self.lr_decay: lr_decay}
                        test_incomplete_img, test_x_tilde, test_new_x_tilde = sess.run(
                            [self.incomplete_img, self.x_tilde, self.new_x_tilde], feed_dict=test_f_d)

                        output_concat = np.concatenate(
                            [real_batch_image, incomplete_img, x_tilde, new_x_tilde], axis=0)
                        local_output_concat = np.concatenate(
                            [local_input_left, local_input_right, local_x_tilde_left, local_x_tilde_right], axis=0)
                        # Note: the test grid puts new_x_tilde before x_tilde.
                        test_output_concat = np.concatenate(
                            [real_test_batch, test_incomplete_img, test_new_x_tilde, test_x_tilde], axis=0)

                        # Floor division keeps the grid row count an int on Python 3 as well.
                        save_images(local_output_concat, [local_output_concat.shape[0] // self.batch_size, self.batch_size],
                                    '{}/{:02d}_local_output.jpg'.format(self.sample_path, step))
                        save_images(output_concat, [output_concat.shape[0] // self.batch_size, self.batch_size],
                                    '{}/{:02d}_output.jpg'.format(self.sample_path, step))
                        save_images(test_output_concat, [test_output_concat.shape[0] // self.batch_size, self.batch_size],
                                    '{}/{:02d}_test_output.jpg'.format(self.sample_path, step))

                    if step % 20000 == 0 and step != 0:
                        self.saver.save(sess, os.path.join(self.write_model_path, 'model_{:06d}.ckpt'.format(step)))

                    step += 1

                # NOTE(review): step is self.max_iters + 1 here, so the final file is
                # named one past the last in-loop checkpoint step.
                save_path = self.saver.save(sess, os.path.join(self.write_model_path, 'model_{:06d}.ckpt'.format(step)))
            finally:
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)

            print("Model saved in file: %s" % save_path)
    def test3(self):
        """Restore the step-100000 checkpoint and dump per-image test outputs.

        Runs ``1000 // self.batch_size`` test batches. Masks and left/right eye
        labels for each batch come from ``self.get_Mask_and_pos``. For every
        image ``i`` of batch ``j``, single JPEGs named
        ``<self.test_sample_path>/<j:02d>_<i:02d>_<suffix>.jpg`` are written for:

        * the fed mask and its inverse (``1 - mask``);
        * the real image;
        * ``incomplete_img``, ``x_tilde`` and ``new_x_tilde``;
        * the local left/right inputs, which are reshaped to
          ``output_size // 2`` squares.
        """
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        full_shape = (1, self.output_size, self.output_size, self.channel)
        # Floor division: '/' would give a float on Python 3 and break np.reshape.
        half_shape = (1, self.output_size // 2, self.output_size // 2, self.channel)

        def save_one(image, shape, j, i, suffix):
            # Write one image as a 1x1 grid. '{:02d}' zero-pads i the same way as
            # j; the previous '{:2d}' put a space into the file name.
            save_images(np.reshape(image, newshape=shape), [1, 1],
                        '{}/{:02d}_{:02d}_{}.jpg'.format(self.test_sample_path, j, i, suffix))

        with tf.Session(config=config) as sess:

            sess.run(init)
            self.saver.restore(sess, os.path.join(self.write_model_path, 'model_{:06d}.ckpt'.format(100000)))
            _, _, _, testbatch, testmask = self.data_ob.input()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # try/finally so the queue-runner threads are always stopped and joined.
            try:
                batch_num = 1000 // self.batch_size
                for j in range(batch_num):

                    real_test_batch, real_eye_pos = sess.run([testbatch, testmask])
                    batch_masks, batch_left_eye_pos, batch_right_eye_pos = self.get_Mask_and_pos(real_eye_pos)
                    f_d = {self.input: real_test_batch, self.input_masks: batch_masks,
                           self.input_left_labels: batch_left_eye_pos, self.input_right_labels: batch_right_eye_pos}
                    (test_incomplete_img, test_x_tilde, test_new_x_tilde,
                     local_input_left, local_input_right, input_masks) = sess.run(
                        [self.incomplete_img, self.x_tilde, self.new_x_tilde,
                         self.local_input_left, self.local_input_right, self.input_masks], feed_dict=f_d)
                    # Invert on the host. Evaluating `1 - self.input_masks` inside
                    # sess.run added a new op to the graph on every batch.
                    r_masks = 1 - input_masks

                    for i in range(self.batch_size):
                        save_one(input_masks[i], full_shape, j, i, 'masks')
                        save_one(r_masks[i], full_shape, j, i, 'r_masks')
                        # The real image used to be written twice to the same file.
                        save_one(real_test_batch[i], full_shape, j, i, 'real')
                        save_one(test_incomplete_img[i], full_shape, j, i, 'in_compelete')
                        save_one(test_x_tilde[i], full_shape, j, i, 'output')
                        save_one(test_new_x_tilde[i], full_shape, j, i, 'new_output')
                        save_one(local_input_left[i], half_shape, j, i, 'local_input_left')
                        save_one(local_input_right[i], half_shape, j, i, 'local_input_right')
            finally:
                coord.request_stop()
                coord.join(threads)