def main():
    """Train the WGAN by alternating critic and generator updates, with logging.

    Loads SVHN via ``load_svhn()`` when ``is_svhn`` is True, otherwise reads
    MNIST from 'MNIST_data'. The graph is built on ``device``.

    Training runs for ``max_iter_step`` outer steps. Each step runs ``citers``
    critic updates, then one generator update. ``citers`` is 100 during the
    first 25 steps and on every 500th step, and ``Citers`` otherwise.

    On steps where ``i % 100 == 99``:

    * the first critic run and the generator run are executed with FULL_TRACE;
    * their summaries and run metadata go to ``log_dir``;
    * losses and the first five logits are logged.

    Every 1000 steps (``i % 1000 == 999``) a checkpoint is written to
    ``ckpt_dir/model.ckpt``.
    """
    if is_svhn is True:
        dataset = load_svhn()
    else:
        tf.logging.info("Read data from {}".format("MNIST_data"))
        dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    with tf.device(device):
        tf.logging.info("Build graph on {}".format(device))
        (opt_g, opt_c, real_data, c_loss, g_loss,
         fake_logit, true_logit) = build_graph()
    merged_all = tf.summary.merge_all()
    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1.0

    def next_feed_dict():
        """Return a feed dict mapping ``real_data`` to one batch of real images."""
        train_img = dataset.train.next_batch(batch_size)[0]
        # Map pixels to [-1, 1] (assumes the loader yields values in [0, 1]).
        train_img = 2 * train_img - 1
        if is_svhn is not True:
            # MNIST: reshape to 28x28, then pad with the background value -1
            # to 32x32 and add a channel axis.
            train_img = np.reshape(train_img, (-1, 28, 28))
            npad = ((0, 0), (2, 2), (2, 2))
            train_img = np.pad(train_img, pad_width=npad, mode='constant',
                               constant_values=-1)
            train_img = np.expand_dims(train_img, -1)  # [batch, 32, 32, 1]
        return {real_data: train_img}

    # Built once and shared by every traced run. Each traced run gets its own
    # fresh RunMetadata below.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        try:
            for i in range(max_iter_step):
                # Train the critic much harder early on and periodically.
                citers = 100 if i < 25 or i % 500 == 0 else Citers
                trace_step = i % 100 == 99
                for j in range(citers):
                    feed_dict = next_feed_dict()
                    if trace_step and j == 0:
                        tf.logging.info(
                            "Write critic summaries at step %i" % i)
                        run_metadata = tf.RunMetadata()
                        (_, merged, c_loss_np, fake_logit_np,
                         true_logit_np) = sess.run(
                            [opt_c, merged_all, c_loss, fake_logit,
                             true_logit],
                            feed_dict=feed_dict, options=run_options,
                            run_metadata=run_metadata)
                        summary_writer.add_summary(merged, i)
                        # Save the run metadata every 100 steps.
                        summary_writer.add_run_metadata(
                            run_metadata, 'critic_metadata {}'.format(i), i)
                        tf.logging.info(
                            "Step = {}, c_loss = {:.4f}, f_logit = {}, "
                            "r_logit = {}".format(i, c_loss_np,
                                                  fake_logit_np[:5],
                                                  true_logit_np[:5]))
                    else:
                        sess.run([opt_c, c_loss], feed_dict=feed_dict)

                feed_dict = next_feed_dict()
                if trace_step:
                    tf.logging.info(
                        "Write generator summaries at step %i" % i)
                    run_metadata = tf.RunMetadata()
                    _, merged, g_loss_np, fake_logit_np = sess.run(
                        [opt_g, merged_all, g_loss, fake_logit],
                        feed_dict=feed_dict, options=run_options,
                        run_metadata=run_metadata)
                    summary_writer.add_summary(merged, i)
                    summary_writer.add_run_metadata(
                        run_metadata, 'generator_metadata {}'.format(i), i)
                    tf.logging.info(
                        "Step = {}, g_loss = {:.4f}, f_logit = {}".format(
                            i, g_loss_np, fake_logit_np[:5]))
                else:
                    sess.run(opt_g, feed_dict=feed_dict)

                if i % 1000 == 999:
                    saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"),
                               global_step=i)
        finally:
            # Flush buffered events and release the event file, even if
            # training is interrupted.
            summary_writer.close()
def main():
    """Train the WGAN by alternating critic and generator updates.

    Loads SVHN via ``load_svhn()`` when ``is_svhn`` is True, otherwise reads
    MNIST from 'MNIST_data'. The graph is built on ``device``.

    Training runs for ``max_iter_step`` outer steps. Each step runs ``citers``
    critic updates, then one generator update. ``citers`` is 100 during the
    first 25 steps and on every 500th step, and ``Citers`` otherwise.

    On steps where ``i % 100 == 99``, the first critic run and the generator
    run are executed with FULL_TRACE, and their summaries and run metadata go
    to ``log_dir``.

    Every 1000 steps (``i % 1000 == 999``) a checkpoint is written to
    ``ckpt_dir/model.ckpt``.
    """
    if is_svhn is True:
        dataset = load_svhn()
    else:
        dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    with tf.device(device):
        opt_g, opt_c, real_data = build_graph()
    merged_all = tf.summary.merge_all()
    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.8

    def next_feed_dict():
        """Return a feed dict mapping ``real_data`` to one batch of real images."""
        train_img = dataset.train.next_batch(batch_size)[0]
        # Map pixels to [-1, 1] (assumes the loader yields values in [0, 1]).
        train_img = 2 * train_img - 1
        if is_svhn is not True:
            # MNIST: reshape to 28x28, then pad with the background value -1
            # to 32x32 and add a channel axis.
            train_img = np.reshape(train_img, (-1, 28, 28))
            npad = ((0, 0), (2, 2), (2, 2))
            train_img = np.pad(train_img, pad_width=npad, mode='constant',
                               constant_values=-1)
            train_img = np.expand_dims(train_img, -1)
        return {real_data: train_img}

    # Built once, up front. Previously these objects were bound only inside
    # the critic loop, so the generator's traced run failed whenever that
    # loop did not execute (citers == 0).
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        try:
            for i in range(max_iter_step):
                # Train the critic much harder early on and periodically.
                citers = 100 if i < 25 or i % 500 == 0 else Citers
                trace_step = i % 100 == 99
                for j in range(citers):
                    feed_dict = next_feed_dict()
                    if trace_step and j == 0:
                        run_metadata = tf.RunMetadata()
                        _, merged = sess.run(
                            [opt_c, merged_all], feed_dict=feed_dict,
                            options=run_options, run_metadata=run_metadata)
                        summary_writer.add_summary(merged, i)
                        summary_writer.add_run_metadata(
                            run_metadata, 'critic_metadata {}'.format(i), i)
                    else:
                        sess.run(opt_c, feed_dict=feed_dict)

                feed_dict = next_feed_dict()
                if trace_step:
                    run_metadata = tf.RunMetadata()
                    _, merged = sess.run(
                        [opt_g, merged_all], feed_dict=feed_dict,
                        options=run_options, run_metadata=run_metadata)
                    summary_writer.add_summary(merged, i)
                    summary_writer.add_run_metadata(
                        run_metadata, 'generator_metadata {}'.format(i), i)
                else:
                    sess.run(opt_g, feed_dict=feed_dict)

                if i % 1000 == 999:
                    saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"),
                               global_step=i)
        finally:
            # Flush buffered events and release the event file, even if
            # training is interrupted.
            summary_writer.close()
beta2 = 0.9 max_iter_step = 20000 channels = 1 if dataset_name == 'mnist' else 3 log_path = './' + dataset_name + '/log_wgan' ckpt_path = './' + dataset_name + '/ckpt_wgan' ckpt_step_path = ckpt_path + '.step' ################################################################## if dataset_name == 'mnist': dataset = input_data.read_data_sets('MNIST_data', one_hot=True) ndf = ngf = 16 elif dataset_name == 'celeba': download_celeb_a() dataset = glob(os.path.join('./data/', 'celebA/*.jpg')) count = len(dataset) elif dataset_name == 'svhn': dataset = load_svhn() ################################################################## def get_image(image_path): image = Image.open(image_path) if image.size != (image_width, image_height): face_width = face_height = 108 j = (image.size[0] - face_width) // 2 i = (image.size[1] - face_height) // 2 image = image.crop([j, i, j + face_width, i + face_height]) image = image.resize([image_width, image_height], Image.BILINEAR) return np.array(image.convert('RGB')) def get_batches():