def run_train_loop(self):
    """Build the matting graph on top of CloudGAN and run the training loop.

    Each iteration calls ``self.gan.run_train(sess)`` and then runs
    ``train_op`` on a freshly sampled batch. Whenever the returned step
    satisfies ``step % 50 == 1``, a validation pass runs on another fresh
    batch. Its merged summaries go to ``self.logdir``, the train and
    validation losses are printed, and the validation outputs are saved as
    PNGs under ``rst/``. Whenever ``step % 500 == 1``, a checkpoint is saved
    under ``self.model_path``.

    The loop ends when the step returned by ``train_op`` reaches
    ``self.iter_step``, or when the queue-runner coordinator requests a stop.
    However the loop ends (normally or through an exception), the summary
    writer is closed, the queue-runner threads are stopped and joined, and
    the session is closed.
    """
    params = init.TrainingParamInitialization()
    self.gan = CloudGAN(params)
    # NOTE(review): train and validation inputs are built from the same GAN
    # tensors; presumably preprocess() yields independent samples per call.
    self.train_img, self.train_alpha, self.train_reflectance = \
        self.data_prepare.preprocess(self.gan.G_sample, self.gan.G_alpha, self.gan.G_relectance)
    self.validate_img, self.validate_alpha, self.validate_reflectance = \
        self.data_prepare.preprocess(self.gan.G_sample, self.gan.G_alpha, self.gan.G_relectance)
    print('build matting net')
    with tf.name_scope('train'):
        train_op = self.train_op(self.train_img, self.train_alpha, self.train_reflectance)
    with tf.name_scope('validate'):
        validate_op = self.validate_op(self.validate_img, self.validate_alpha,
                                       self.validate_reflectance, reuse=True)

    def sample_feed_dict():
        # Fresh X / Z / BG batches for the GAN placeholders. The call order
        # (X, Z, BG) is kept as-is in case get_batch draws from a shared RNG.
        x_mb = utils.get_batch(self.gan.data, self.batch_size, 'cloudimage', self.img_size)
        z_mb = utils.get_batch(self.gan.data, self.batch_size, 'cloudimage', self.img_size)
        bg_mb = utils.get_batch(self.gan.data, self.batch_size, 'background', self.img_size)
        return {self.gan.X: x_mb, self.gan.Z: z_mb, self.gan.BG: bg_mb}

    log_step = 50
    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())
        self.load_checkpoints(sess)
        saver = tf.train.Saver()
        merge_op = tf.summary.merge_all()
        # Loop-invariant validation fetch list: the merged summary op first,
        # then the validate_op fetches (unpacked below as image, alpha,
        # reflectance and loss).
        merges = [merge_op] + validate_op
        summary_writer = tf.summary.FileWriter(self.logdir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        if not os.path.exists('rst'):
            os.mkdir('rst')
        try:
            step = 0
            # should_stop() lets a failing queue-runner thread end the loop.
            while step < self.iter_step and not coord.should_stop():
                self.gan.run_train(sess)
                # train_op is a 3-item fetch; the last two items are the loss
                # and the step counter that drives termination.
                _, train_loss, step = sess.run(train_op, feed_dict=sample_feed_dict())
                if np.mod(step, log_step) == 1:
                    summary, image, alpha_image, reflectance_image, validate_loss = sess.run(
                        merges, feed_dict=sample_feed_dict())
                    summary_writer.add_summary(summary, step)
                    print("step:%d,loss:%f,validate_loss:%f" % (step, train_loss, validate_loss))
                    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2;
                    # this requires an older SciPy (or a switch to imageio).
                    misc.imsave('rst/' + str(step) + '_image.png', image)
                    misc.imsave('rst/' + str(step) + '_alpha.png', alpha_image)
                    misc.imsave('rst/' + str(step) + '_foreground.png', reflectance_image)
                if np.mod(step, 500) == 1:
                    saver.save(sess, os.path.join(self.model_path, 'model'), step)
        finally:
            # Close the writer first so buffered summaries are flushed, then
            # stop and join the input threads. coord.join may re-raise an
            # exception reported by a runner thread.
            summary_writer.close()
            coord.request_stop()
            coord.join(threads)
    finally:
        sess.close()
With this project, you can train a model to solve the following inverse problems: - on MNIST and CIFAR-10 datasets for separating superimposed images. - image denoising on MNIST - remove speckle and streak noise in CAPTCHAs All the above tasks are trained w/ or w/o the help of pair-wise supervision. """ import tensorflow as tf import tensorflow.contrib.slim as slim import initializer as init # All parameters used in this file Params = init.TrainingParamInitialization() def generator(Y, is_training): # define the number of filters in each conv layer mm = 64 if Params.task_name in ['unmixing_mnist_mnist', 'denoising', 'captcha']: n_channels = 1 else: n_channels = 3 G_input_size = Params.G_input_size with tf.variable_scope('G_scope'):