def main():
    """Train the DCGAN, periodically checkpointing G/D and writing sample images.

    Relies on module-level settings: ``s_size``, ``batch_size``, ``learning_rate``,
    ``epochs``, ``save_dir``.
    """
    dcgan = DCGAN(s_size=s_size, batch_size=batch_size)
    train_im, total_imgs = load_image()
    total_batch = int(total_imgs / batch_size)
    losses = dcgan.loss(train_im)
    train_op = dcgan.train(losses, learning_rate=learning_rate)

    # Cap GPU memory usage and pin CPU threading so several runs can share a box.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.33)
    config = tf.ConfigProto(gpu_options=gpu_options,
                            device_count={"CPU": 8},
                            inter_op_parallelism_threads=1,
                            intra_op_parallelism_threads=1)

    # BUG FIX: the session was previously created as tf.Session() without
    # `config`, so the GPU-memory fraction and thread limits above were
    # silently ignored.
    with tf.Session(config=config) as sess:
        # Initialize variables BEFORE starting queue runners, so input threads
        # never race against uninitialized state.
        init = tf.global_variables_initializer()
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        g_saver = tf.train.Saver(dcgan.g.variables)
        d_saver = tf.train.Saver(dcgan.d.variables)
        if os.path.isdir(save_dir):
            # Resume from the newest checkpoint in each sub-model directory.
            g_saver.restore(sess, tf.train.latest_checkpoint(save_dir + '/g_model'))
            d_saver.restore(sess, tf.train.latest_checkpoint(save_dir + '/d_model'))
        else:
            os.mkdir(save_dir)

        # Fixed latent batch so the sampled grid is comparable across epochs.
        sample_z = np.float32(
            np.random.uniform(-1, 1, [dcgan.batch_size, dcgan.z_dim]))
        images = dcgan.sample_images(5, 5, inputs=sample_z)

        print("Start training")
        for step in range(1, epochs + 1):
            start_time = time.time()
            for batch in range(total_batch):
                _, g_loss, d_loss = sess.run(
                    [train_op, losses[dcgan.g], losses[dcgan.d]])
            # Reports the losses of the LAST batch of the epoch.
            print("epochs {} loss = G: {:.8f}, D: {:.8f} run time:{:.4f} sec"\
                .format(step, g_loss, d_loss, time.time()-start_time))
            g_saver.save(sess, save_dir + '/g_model/g.ckpt', global_step=step)
            d_saver.save(sess, save_dir + '/d_model/d.ckpt', global_step=step)
            # sample_images() yields an encoded JPEG string; write it raw.
            with open('./test/%05d.jpg' % step, 'wb') as f:
                f.write(sess.run(images))

        coord.request_stop()
        coord.join(threads)
# """Directory where to write event logs and checkpoint.""") # tensorflow.app.flags.DEFINE_integer('num_examples_per_epoch_for_train', 5000, # """number of examples for train""") def get_images_batch(): return [] dcgan = DCGAN( g_depths=[8192, 4096, 2048, 1024, 512, 256, 128], d_depths=[64, 128, 256, 512, 1024, 2048, 4096], s_size=4, ) train_images = get_images_batch() losses = dcgan.loss(train_images) train_op = dcgan.train(losses) with tensorflow.Session() as sess: sess.run(tensorflow.global_variables_initializer()) for step in range(FLAGS.max_steps): _, g_loss_value, d_loss_value = sess.run( [train_op, losses[dcgan.g], losses[dcgan.d]]) images = dcgan.sample_images() with tensorflow.Session() as sess: # restore trained variables generated = sess.run(images) with open('output.binary', 'wb') as f:
def main(_):
    """Train a DCGAN with feature matching, or run image completion / sampling.

    Branches on FLAGS.is_train / FLAGS.is_complete:
      * train:    adversarial training with feature-matching regularization,
                  periodic image dumps, summaries and checkpoints.
      * complete: inpaint a center-masked region of real images by optimizing
                  the latent vector zhat (projected gradient with momentum).
      * neither:  dump a single 8x8 grid of generated samples.
    """
    dcgan = DCGAN(s_size=6)
    traindata = read_decode(dcgan.batch_size, dcgan.s_size)
    losses = dcgan.loss(traindata)

    # feature matching: penalize the L2 distance between mean discriminator
    # conv4 features of generated ('dg/...') and real ('dt/...') batches.
    graph = tf.get_default_graph()
    features_g = tf.reduce_mean(
        graph.get_tensor_by_name('dg/d/conv4/outputs:0'), 0)
    features_t = tf.reduce_mean(
        graph.get_tensor_by_name('dt/d/conv4/outputs:0'), 0)
    losses[dcgan.g] += tf.multiply(tf.nn.l2_loss(features_g - features_t), 0.05)

    tf.summary.scalar('g loss', losses[dcgan.g])
    tf.summary.scalar('d loss', losses[dcgan.d])
    train_op = dcgan.train(losses, learning_rate=0.0001)
    summary_op = tf.summary.merge_all()

    # Separate savers so G and D can be restored independently.
    g_saver = tf.train.Saver(dcgan.g.variables, max_to_keep=15)
    d_saver = tf.train.Saver(dcgan.d.variables, max_to_keep=15)
    g_checkpoint_path = os.path.join(FLAGS.log_dir, 'g.ckpt')
    d_checkpoint_path = os.path.join(FLAGS.log_dir, 'd.ckpt')
    # Restore paths are pinned to the step number in FLAGS.latest_ckpt.
    g_checkpoint_restore_path = os.path.join(
        FLAGS.log_dir, 'g.ckpt-' + str(FLAGS.latest_ckpt))
    d_checkpoint_restore_path = os.path.join(
        FLAGS.log_dir, 'd.ckpt-' + str(FLAGS.latest_ckpt))

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, graph=sess.graph)
        sess.run(tf.global_variables_initializer())
        # restore or initialize generator (the .meta file existing is used as
        # the "checkpoint exists" probe).
        if os.path.exists(g_checkpoint_restore_path + '.meta'):
            print('Restoring variables:')
            for v in dcgan.g.variables:
                print(' ' + v.name)
            g_saver.restore(sess, g_checkpoint_restore_path)

        if FLAGS.is_train and not FLAGS.is_complete:
            # restore or initialize discriminator
            if os.path.exists(d_checkpoint_restore_path + '.meta'):
                print('Restoring variables:')
                for v in dcgan.d.variables:
                    print(' ' + v.name)
                d_saver.restore(sess, d_checkpoint_restore_path)
            # setup for monitoring
            if not os.path.exists(FLAGS.images_dir):
                os.makedirs(FLAGS.images_dir)
            if not os.path.exists(FLAGS.log_dir):
                os.makedirs(FLAGS.log_dir)
            # Fixed z so sample grids are comparable across iterations.
            sample_z = sess.run(
                tf.random_uniform([dcgan.batch_size, dcgan.z_dim],
                                  minval=-1.0,
                                  maxval=1.0))
            images = dcgan.sample_images(5, 5, inputs=sample_z)
            filename = os.path.join(FLAGS.images_dir, '000000.jpg')
            with open(filename, 'wb') as f:
                f.write(sess.run(images))
            tf.train.start_queue_runners(sess=sess)
            for itr in range(FLAGS.latest_ckpt + 1, FLAGS.max_itr):
                start_time = time.time()
                _, g_loss, d_loss = sess.run(
                    [train_op, losses[dcgan.g], losses[dcgan.d]])
                duration = time.time() - start_time
                print('step: %d, loss: (G: %.8f, D: %.8f), time taken: %.3f' %
                      (itr, g_loss, d_loss, duration))
                if itr % 5000 == 0:
                    # Images generated
                    filename = os.path.join(FLAGS.images_dir, '%06d.jpg' % itr)
                    with open(filename, 'wb') as f:
                        f.write(sess.run(images))
                    # Summary
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, itr)
                    # Checkpoints
                    g_saver.save(sess, g_checkpoint_path, global_step=itr)
                    d_saver.save(sess, d_checkpoint_path, global_step=itr)
        elif FLAGS.is_complete:
            # restore discriminator
            if os.path.exists(d_checkpoint_restore_path + '.meta'):
                print('Restoring variables:')
                for v in dcgan.d.variables:
                    print(' ' + v.name)
                d_saver.restore(sess, d_checkpoint_restore_path)
            # Directory to save completed images
            if not os.path.exists(FLAGS.complete_dir):
                os.makedirs(FLAGS.complete_dir)
            # Create mask: zero out the central 50% square (scale=0.25 margin
            # on each side), 1.0 elsewhere.
            scale = 0.25
            mask = np.ones(dcgan.image_shape)
            sz = dcgan.image_size
            l = int(dcgan.image_size * scale)
            u = int(dcgan.image_size * (1.0 - scale))
            mask[l:u, l:u, :] = 0.0
            masks = np.expand_dims(mask, axis=0)  # add batch dim for feed
            # Read actual images
            images = glob(os.path.join(FLAGS.complete_src, '*.jpg'))
            for idx in range(len(images)):
                image_src = get_image(images[idx], dcgan.image_size)
                image = np.expand_dims(image_src, axis=0)
                # Save image after crop (y)
                orig_fn = os.path.join(
                    FLAGS.complete_dir, 'original_image_{:02d}.jpg'.format(idx))
                # NOTE(review): argument order (array, path) is reversed
                # relative to scipy.misc.imsave(path, array) — presumably a
                # project-local helper; confirm.
                imsave(image_src, orig_fn)
                # Save corrupted image (y . M)
                corrupted_fn = os.path.join(
                    FLAGS.complete_dir,
                    'corrupted_image_{:02d}.jpg'.format(idx))
                masked_image = np.multiply(image_src, mask)
                imsave(masked_image, corrupted_fn)
                # Optimize zhat with Nesterov-style momentum so that
                # G(zhat) matches the unmasked pixels (dcgan.complete_loss).
                zhat = np.random.uniform(-1, 1, size=(1, dcgan.z_dim))
                v = 0
                momentum = 0.9
                lr = 0.01
                for i in range(0, 10001):
                    fd = {
                        dcgan.zhat: zhat,
                        dcgan.mask: masks,
                        dcgan.image: image
                    }
                    run = [
                        dcgan.complete_loss, dcgan.grad_complete_loss, dcgan.G
                    ]
                    loss, g, G_imgs = sess.run(run, feed_dict=fd)
                    v_prev = np.copy(v)
                    v = momentum * v - lr * g[0]
                    zhat += -momentum * v_prev + (1 + momentum) * v
                    # keep zhat inside the generator's training range
                    zhat = np.clip(zhat, -1, 1)
                    if i % 100 == 0:
                        hats_fn = os.path.join(
                            FLAGS.complete_dir,
                            'hats_img_{:02d}_{:04d}.jpg'.format(idx, i))
                        save_images(G_imgs[0, :, :, :], hats_fn)
                        # Completed image = known pixels from the input plus
                        # generated pixels inside the masked hole.
                        inv_masked_hat_image = np.multiply(
                            G_imgs, 1.0 - masks)
                        completed = masked_image + inv_masked_hat_image
                        complete_fn = os.path.join(
                            FLAGS.complete_dir,
                            'completed_{:02d}_{:04d}.jpg'.format(idx, i))
                        save_images(completed[0, :, :, :], complete_fn)
        else:
            # Plain sampling: dump one 8x8 grid of generated images.
            generated = sess.run(dcgan.sample_images(8, 8))
            if not os.path.exists(FLAGS.images_dir):
                os.makedirs(FLAGS.images_dir)
            filename = os.path.join(FLAGS.images_dir, 'generated_image.jpg')
            with open(filename, 'wb') as f:
                print('write to %s' % filename)
                f.write(generated)
def main(_):
    """Train / complete / sample a DCGAN on belief-grid data.

    The input tensor is split into occupancy (BelO) and free-space (BelF)
    belief channels; a binary certainty mask derived from them can optionally
    drive the inpainting mask (FLAGS.masktype == 'Uncertainty').

    Branches on FLAGS.is_train / FLAGS.is_complete exactly like the sibling
    training scripts in this file.
    """
    dcgan = DCGAN(batch_size=FLAGS.batch_size, s_size=32,
                  nb_channels=FLAGS.nb_channels)  # ssize6
    traindata = read_decode(FLAGS.data_dir, dcgan.batch_size)  # , dcgan.s_size
    # Split into occupied / free belief halves; train only on the occupancy half.
    # assumes the decoded tensor is 1024 wide on axis 2 — TODO confirm.
    BelO, BelF = tf.split(traindata, [512, 512], axis=2)
    traindata = BelO
    # Certainty in [0, 1]: (1 + BelF + BelO) / 2; threshold at 0.4 for the mask.
    Certainty = tf.div(tf.add(1.0, tf.add(BelF, BelO)), 2)
    CertaintyMask = tf.to_int32(Certainty > 0.4)

    losses = dcgan.loss(traindata)

    # feature matching: L2 distance between mean discriminator conv4 features
    # of generated ('dg/...') and real ('dt/...') batches, weighted 0.05.
    graph = tf.get_default_graph()
    features_g = tf.reduce_mean(
        graph.get_tensor_by_name('dg/d/conv4/outputs:0'), 0)
    features_t = tf.reduce_mean(
        graph.get_tensor_by_name('dt/d/conv4/outputs:0'), 0)
    losses[dcgan.g] += tf.multiply(tf.nn.l2_loss(features_g - features_t), 0.05)

    tf.summary.scalar('g_loss', losses[dcgan.g])
    tf.summary.scalar('d_loss', losses[dcgan.d])
    train_op = dcgan.train(losses, learning_rate=0.0001)
    summary_op = tf.summary.merge_all()

    g_saver = tf.train.Saver(dcgan.g.variables, max_to_keep=15)
    d_saver = tf.train.Saver(dcgan.d.variables, max_to_keep=15)
    g_checkpoint_path = os.path.join(FLAGS.log_dir, 'g.ckpt')
    d_checkpoint_path = os.path.join(FLAGS.log_dir, 'd.ckpt')
    g_checkpoint_restore_path = os.path.join(
        FLAGS.log_dir, 'g.ckpt-' + str(FLAGS.latest_ckpt))
    d_checkpoint_restore_path = os.path.join(
        FLAGS.log_dir, 'd.ckpt-' + str(FLAGS.latest_ckpt))

    with tf.Session() as sess:
        CertaintyMask = tf.squeeze(CertaintyMask)
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, graph=sess.graph)
        sess.run(tf.global_variables_initializer())
        # restore or initialize generator
        if os.path.exists(g_checkpoint_restore_path + '.meta'):
            print('Restoring variables:')
            for v in dcgan.g.variables:
                print(' ' + v.name)
            g_saver.restore(sess, g_checkpoint_restore_path)

        if FLAGS.is_train and not FLAGS.is_complete:
            # restore or initialize discriminator
            if os.path.exists(d_checkpoint_restore_path + '.meta'):
                print('Restoring variables:')
                for v in dcgan.d.variables:
                    print(' ' + v.name)
                d_saver.restore(sess, d_checkpoint_restore_path)
            # setup for monitoring
            if not os.path.exists(FLAGS.images_dir):
                os.makedirs(FLAGS.images_dir)
            if not os.path.exists(FLAGS.log_dir):
                os.makedirs(FLAGS.log_dir)
            sample_z = sess.run(
                tf.random_uniform([dcgan.batch_size, dcgan.z_dim],
                                  minval=-1.0,
                                  maxval=1.0))
            images = dcgan.sample_images(5, 5, inputs=sample_z)
            filename = os.path.join(FLAGS.images_dir, '000000.jpg')
            with open(filename, 'wb') as f:
                f.write(sess.run(images))
            tf.train.start_queue_runners(sess=sess)
            for itr in range(FLAGS.latest_ckpt + 1, FLAGS.max_itr):
                start_time = time.time()
                _, g_loss, d_loss = sess.run(
                    [train_op, losses[dcgan.g], losses[dcgan.d]])
                duration = time.time() - start_time
                print('step: %d, loss: (G: %.8f, D: %.8f), time taken: %.3f' %
                      (itr, g_loss, d_loss, duration))
                if itr % 5000 == 0:
                    # Images generated
                    filename = os.path.join(FLAGS.images_dir, '%06d.jpg' % itr)
                    with open(filename, 'wb') as f:
                        f.write(sess.run(images))
                    # Summary
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, itr)
                    # Checkpoints
                    g_saver.save(sess, g_checkpoint_path, global_step=itr)
                    d_saver.save(sess, d_checkpoint_path, global_step=itr)
        elif FLAGS.is_complete:
            # restore discriminator
            if os.path.exists(d_checkpoint_restore_path + '.meta'):
                print('Restoring variables:')
                for v in dcgan.d.variables:
                    print(' ' + v.name)
                d_saver.restore(sess, d_checkpoint_restore_path)
            # Directory to save completed images
            if not os.path.exists(FLAGS.complete_dir):
                os.makedirs(FLAGS.complete_dir)
            # Create mask
            if FLAGS.masktype == 'center':
                # zero out the central 50% square
                scale = 0.25
                mask = np.ones(dcgan.image_shape)
                sz = dcgan.image_size
                l = int(sz * scale)
                u = int(sz * (1.0 - scale))
                mask[l:u, l:u, :] = 0.0
            if FLAGS.masktype == 'random':
                # drop a random 80% of the pixels
                fraction_masked = 0.8
                mask = np.ones(dcgan.image_shape)
                mask[np.random.random(dcgan.image_shape[:2]) <
                     fraction_masked] = 0.0
            if FLAGS.masktype == 'Uncertainty':
                # BUG FIX: CertaintyMask_npArray was referenced here but never
                # assigned (the .eval() call was commented out), raising a
                # NameError. Evaluate the mask tensor in-session; queue
                # runners must be running because it depends on the input
                # pipeline.
                tf.train.start_queue_runners(sess=sess)
                CertaintyMask_npArray = sess.run(CertaintyMask)
                mask = np.reshape(CertaintyMask_npArray, dcgan.image_shape)
            # NOTE(review): any other masktype leaves `mask` undefined and
            # fails below — confirm the flag is validated upstream.
            # Read actual images
            originals = glob(os.path.join(FLAGS.complete_src, '*.jpg'))
            batch_mask = np.expand_dims(mask, axis=0)  # add batch dim
            for idx in range(len(originals)):
                image_src = get_image(originals[idx], dcgan.image_size,
                                      nb_channels=FLAGS.nb_channels)
                if FLAGS.nb_channels == 3:
                    image = np.expand_dims(image_src, axis=0)
                elif FLAGS.nb_channels == 1:
                    image = np.expand_dims(np.expand_dims(image_src, axis=3),
                                           axis=0)
                # Save original image (y)
                filename = os.path.join(
                    FLAGS.complete_dir, 'original_image_{:02d}.jpg'.format(idx))
                imsave(image_src, filename)
                # Save corrupted image (y . M)
                filename = os.path.join(
                    FLAGS.complete_dir,
                    'corrupted_image_{:02d}.jpg'.format(idx))
                if FLAGS.nb_channels == 3:
                    masked_image = np.multiply(image_src, mask)
                    imsave(masked_image, filename)
                elif FLAGS.nb_channels == 1:
                    masked_image = np.multiply(
                        np.expand_dims(image_src, axis=3), mask)
                    imsave(masked_image[:, :, 0], filename)
                # Optimize zhat (momentum gradient descent on completion loss).
                zhat = np.random.uniform(-1, 1, size=(1, dcgan.z_dim))
                v = 0
                momentum = 0.9
                lr = 0.01
                for i in range(0, 1001):
                    fd = {
                        dcgan.zhat: zhat,
                        dcgan.mask: batch_mask,
                        dcgan.image: image
                    }
                    run = [
                        dcgan.complete_loss, dcgan.grad_complete_loss, dcgan.G
                    ]
                    loss, g, G_imgs = sess.run(run, feed_dict=fd)
                    v_prev = np.copy(v)
                    v = momentum * v - lr * g[0]
                    zhat += -momentum * v_prev + (1 + momentum) * v
                    zhat = np.clip(zhat, -1, 1)
                    if i % 100 == 0:
                        filename = os.path.join(
                            FLAGS.complete_dir,
                            'hats_img_{:02d}_{:04d}.jpg'.format(idx, i))
                        if FLAGS.nb_channels == 3:
                            save_images(G_imgs[0, :, :, :], filename)
                        if FLAGS.nb_channels == 1:
                            save_images(G_imgs[0, :, :, 0], filename)
                        # Known pixels from input + generated pixels in hole.
                        inv_masked_hat_image = np.multiply(
                            G_imgs, 1.0 - batch_mask)
                        completed = masked_image + inv_masked_hat_image
                        filename = os.path.join(
                            FLAGS.complete_dir,
                            'completed_{:02d}_{:04d}.jpg'.format(idx, i))
                        if FLAGS.nb_channels == 3:
                            save_images(completed[0, :, :, :], filename)
                        if FLAGS.nb_channels == 1:
                            save_images(completed[0, :, :, 0], filename)
        else:
            # Plain sampling: dump one 8x8 grid of generated images.
            generated = sess.run(dcgan.sample_images(8, 8))
            if not os.path.exists(FLAGS.images_dir):
                os.makedirs(FLAGS.images_dir)
            filename = os.path.join(FLAGS.images_dir, 'generated_image.jpg')
            with open(filename, 'wb') as f:
                print('write to %s' % filename)
                f.write(generated)
def main(argv=None):
    """Train the DCGAN, resuming from the newest G/D checkpoints if present.

    The current step counter is recovered from the restored generator
    checkpoint filename so training continues where it left off.
    """
    # Creating an object of the DCGAN
    dcgan = DCGAN(s_size=10)
    # Call to function to read data in .tfrecord format
    traindata = inputs(dcgan.batch_size, dcgan.s_size)
    print('Train data', traindata)
    # Calculating the losses
    losses = dcgan.loss(traindata)
    tf.summary.scalar('g loss', losses[dcgan.g])
    tf.summary.scalar('d loss', losses[dcgan.d])
    # Minimize the Generator and the Discriminator losses
    train_op = dcgan.train(losses)
    summary_op = tf.summary.merge_all()

    # Savers for the generator and discriminator states
    g_saver = tf.train.Saver(dcgan.g.variables)
    d_saver = tf.train.Saver(dcgan.d.variables)
    # Checkpoint directories (one per sub-model)
    g_checkpoint_path = os.path.join(FLAGS.logdir, 'gckpt/')
    print('G checkpoint path: ', g_checkpoint_path)
    d_checkpoint_path = os.path.join(FLAGS.logdir, 'dckpt/')
    print('D Checkpoint Path:', d_checkpoint_path)
    if not os.path.exists(g_checkpoint_path):
        os.makedirs(g_checkpoint_path)
    if not os.path.exists(d_checkpoint_path):
        os.makedirs(d_checkpoint_path)
    # Ensure the sample-image directory exists before the loop writes to it.
    if not os.path.exists(FLAGS.images_dir):
        os.makedirs(FLAGS.images_dir)

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph)
        newStepNo = 0
        # restore or initialize generator
        sess.run(tf.global_variables_initializer())
        gckpt = tf.train.get_checkpoint_state(g_checkpoint_path)
        if gckpt and gckpt.model_checkpoint_path:
            g_saver.restore(sess, gckpt.model_checkpoint_path)
            print('Model restored from ' + gckpt.model_checkpoint_path)
            # Recover the global step from the checkpoint file name
            # (".../g.ckpt-<step>"). rsplit guards against '-' elsewhere
            # in the directory path.
            newStepCheck = gckpt.model_checkpoint_path
            newStepNo = int(newStepCheck.rsplit('-', 1)[-1])
        # restore or initialize discriminator
        dckpt = tf.train.get_checkpoint_state(d_checkpoint_path)
        if dckpt and dckpt.model_checkpoint_path:
            d_saver.restore(sess, dckpt.model_checkpoint_path)
            print('Model restored from ' + dckpt.model_checkpoint_path)

        # setup for monitoring: fixed z so samples are comparable across steps
        sample_z = sess.run(tf.random_uniform([dcgan.batch_size, dcgan.z_dim],
                                              minval=-1.0,
                                              maxval=1.0))
        images = dcgan.sample_images(1, 1, inputs=sample_z)

        # start training
        tf.train.start_queue_runners(sess=sess)
        while newStepNo <= FLAGS.max_steps:
            start_time = time.time()
            _, g_loss, d_loss = sess.run(
                [train_op, losses[dcgan.g], losses[dcgan.d]])
            duration = time.time() - start_time
            print('{}: step {:5d}, loss = (G: {:.8f}, D: {:.8f}) ({:.3f} sec/batch)'.format(
                datetime.now(), newStepNo, g_loss, d_loss, duration))
            # save generated images
            if newStepNo % 100 == 0:
                # summary
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, newStepNo)
                # sample images
                filename = os.path.join(FLAGS.images_dir, '%05d.png' % newStepNo)
                with open(filename, 'wb') as f:
                    f.write(sess.run(images))
            # save variables
            if newStepNo % 500 == 0:
                g_saver.save(sess, os.path.join(g_checkpoint_path, 'g.ckpt'),
                             global_step=newStepNo)
                print('Save mode for G in checkpoint_path: ', g_checkpoint_path)
                d_saver.save(sess, os.path.join(d_checkpoint_path, 'd.ckpt'),
                             global_step=newStepNo)
                print('Save mode for D in checkpoint_path: ', d_checkpoint_path)
            newStepNo = newStepNo + 1
def main(argv=None):
    """
    Main function that calls a training batch sample and trains the model.

    argv[1] selects the dataset input pipeline: 'lsun' or 'celeb'.
    """
    # BUG FIX: the original used the Python 2 `print "..."` statement, which
    # is a SyntaxError under Python 3 (the rest of this file uses print()).
    if len(argv) < 2:
        print("Please input desired dataset in cmd line: `lsun` or `celeb`.")
        sys.exit()
    dcgan = DCGAN(batch_size=64, s_size=4)
    traindata = None
    if argv[1] == 'lsun':
        # load input pipeline for LSUN dataset
        traindata = load_data.lsun_inputs(dcgan.batch_size, dcgan.s_size)
    elif argv[1] == 'celeb':
        # load input pipeline for CelebA dataset
        traindata = load_data.celeb_inputs(dcgan.batch_size, dcgan.s_size)
    else:
        # Robustness: previously an unknown dataset name left traindata=None
        # and crashed later inside dcgan.loss(); fail fast instead.
        print("Please input desired dataset in cmd line: `lsun` or `celeb`.")
        sys.exit()
    losses = dcgan.loss(traindata)

    # feature mapping: regularize G toward matching mean discriminator conv4
    # features of real ('dt/...') vs generated ('dg/...') batches.
    graph = tf.get_default_graph()
    features_g = tf.reduce_mean(
        graph.get_tensor_by_name('dg/d/conv4/outputs:0'), 0)
    features_t = tf.reduce_mean(
        graph.get_tensor_by_name('dt/d/conv4/outputs:0'), 0)
    # adding the regularization term
    losses[dcgan.g] += tf.multiply(tf.nn.l2_loss(features_t - features_g), 0.05)

    # train and summary
    tf.summary.scalar('g_loss', losses[dcgan.g])
    tf.summary.scalar('d_loss', losses[dcgan.d])
    train_op = dcgan.train(losses, learning_rate=0.0002)
    summary_op = tf.summary.merge_all()

    g_saver = tf.train.Saver(dcgan.g.variables)
    d_saver = tf.train.Saver(dcgan.d.variables)
    g_checkpoint_path = os.path.join(FLAGS.logdir, 'g.ckpt')
    d_checkpoint_path = os.path.join(FLAGS.logdir, 'd.ckpt')

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph)
        # restore or initialize generator
        sess.run(tf.global_variables_initializer())
        # NOTE(review): Saver.save(..., global_step=step) writes files named
        # "g.ckpt-<step>...", so this exact-path existence check presumably
        # never matches a step-suffixed checkpoint — confirm intended restore
        # behavior before relying on resume.
        if os.path.exists(g_checkpoint_path):
            print('restore variables:')
            for v in dcgan.g.variables:
                print(' ' + v.name)
            g_saver.restore(sess, g_checkpoint_path)
        if os.path.exists(d_checkpoint_path):
            print('restore variables:')
            for v in dcgan.d.variables:
                print(' ' + v.name)
            d_saver.restore(sess, d_checkpoint_path)

        # setup for monitoring: fixed z so samples are comparable across steps
        sample_z = sess.run(tf.random_uniform([dcgan.batch_size, dcgan.z_dim],
                                              minval=-1.0,
                                              maxval=1.0))
        images = dcgan.sample_images(inputs=sample_z)

        # start training
        tf.train.start_queue_runners(sess=sess)
        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, g_loss, d_loss = sess.run(
                [train_op, losses[dcgan.g], losses[dcgan.d]])
            duration = time.time() - start_time
            print('{}: step {:5d}, loss = (G: {:.8f}, D: {:.8f}) ({:.3f} sec/batch)'.format(
                datetime.now(), step, g_loss, d_loss, duration))
            # save generated images
            if step % 100 == 0:
                # summary
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)
                # sample images
                filename = os.path.join(FLAGS.images_dir, '%05d.jpg' % step)
                with open(filename, 'wb') as f:
                    f.write(sess.run(images))
            # save variables
            if step % 500 == 0:
                g_saver.save(sess, g_checkpoint_path, global_step=step)
                d_saver.save(sess, d_checkpoint_path, global_step=step)
def main(_):
    """Train the DCGAN (logging progress to ./console.log) or dump samples.

    FLAGS.is_train selects training; otherwise one 8x8 grid of generated
    images is written to FLAGS.images_dir.
    """
    dcgan = DCGAN(batch_size=FLAGS.batch_size, s_size=6,
                  nb_channels=FLAGS.nb_channels)
    traindata = read_decode(FLAGS.data_dir, dcgan.batch_size, dcgan.s_size)
    losses = dcgan.loss(traindata)

    # feature matching: L2 distance between mean discriminator conv4 features
    # of generated ('dg/...') and real ('dt/...') batches, weighted 0.05.
    graph = tf.get_default_graph()
    features_g = tf.reduce_mean(
        graph.get_tensor_by_name('dg/d/conv4/outputs:0'), 0)
    features_t = tf.reduce_mean(
        graph.get_tensor_by_name('dt/d/conv4/outputs:0'), 0)
    losses[dcgan.g] += tf.multiply(tf.nn.l2_loss(features_g - features_t), 0.05)

    tf.summary.scalar('g_loss', losses[dcgan.g])
    tf.summary.scalar('d_loss', losses[dcgan.d])
    train_op = dcgan.train(losses, learning_rate=0.0001)
    summary_op = tf.summary.merge_all()

    g_saver = tf.train.Saver(dcgan.g.variables, max_to_keep=15)
    d_saver = tf.train.Saver(dcgan.d.variables, max_to_keep=15)
    g_checkpoint_path = os.path.join(FLAGS.log_dir, 'g.ckpt')
    d_checkpoint_path = os.path.join(FLAGS.log_dir, 'd.ckpt')
    g_checkpoint_restore_path = os.path.join(
        FLAGS.log_dir, 'g.ckpt-' + str(FLAGS.latest_ckpt))
    d_checkpoint_restore_path = os.path.join(
        FLAGS.log_dir, 'd.ckpt-' + str(FLAGS.latest_ckpt))

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, graph=sess.graph)
        sess.run(tf.global_variables_initializer())
        # restore or initialize generator
        if os.path.exists(g_checkpoint_restore_path + '.meta'):
            print('Restoring variables:')
            for v in dcgan.g.variables:
                print(' ' + v.name)
            g_saver.restore(sess, g_checkpoint_restore_path)

        if FLAGS.is_train:
            # restore or initialize discriminator
            if os.path.exists(d_checkpoint_restore_path + '.meta'):
                print('Restoring variables:')
                for v in dcgan.d.variables:
                    print(' ' + v.name)
                d_saver.restore(sess, d_checkpoint_restore_path)
            # setup for monitoring
            if not os.path.exists(FLAGS.images_dir):
                os.makedirs(FLAGS.images_dir)
            if not os.path.exists(FLAGS.log_dir):
                os.makedirs(FLAGS.log_dir)
            sample_z = sess.run(
                tf.random_uniform([dcgan.batch_size, dcgan.z_dim],
                                  minval=-1.0,
                                  maxval=1.0))
            images = dcgan.sample_images(5, 5, inputs=sample_z)
            filename = os.path.join(FLAGS.images_dir, '000000.jpg')
            with open(filename, 'wb') as f:
                f.write(sess.run(images))
            tf.train.start_queue_runners(sess=sess)
            # BUG FIX: the original re-opened './console.log' with mode 'w+'
            # on EVERY iteration (leaking file descriptors and truncating the
            # log so only the last line survived) and used the Python 2
            # `print >> f1` chevron syntax. Open once, append each line, and
            # flush so progress is visible while training.
            with open('./console.log', 'w') as logf:
                for itr in range(FLAGS.latest_ckpt + 1, FLAGS.max_itr):
                    start_time = time.time()
                    _, g_loss, d_loss = sess.run(
                        [train_op, losses[dcgan.g], losses[dcgan.d]])
                    duration = time.time() - start_time
                    print('step: %d, loss: (G: %.8f, D: %.8f), time taken: %.3f'
                          % (itr, g_loss, d_loss, duration), file=logf)
                    logf.flush()
                    if itr % 5000 == 0:
                        # Images generated
                        filename = os.path.join(FLAGS.images_dir,
                                                '%06d.jpg' % itr)
                        with open(filename, 'wb') as f:
                            f.write(sess.run(images))
                        # Summary
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, itr)
                        # Checkpoints
                        g_saver.save(sess, g_checkpoint_path, global_step=itr)
                        d_saver.save(sess, d_checkpoint_path, global_step=itr)
        else:
            # Plain sampling: dump one 8x8 grid of generated images.
            generated = sess.run(dcgan.sample_images(8, 8))
            if not os.path.exists(FLAGS.images_dir):
                os.makedirs(FLAGS.images_dir)
            filename = os.path.join(FLAGS.images_dir, 'generated_image.jpg')
            with open(filename, 'wb') as f:
                print('write to %s' % filename)
                f.write(generated)
EXAMPLE_NUM = len(data['imgNames']) FEATURE_DIM = len(data['imgFeatures'][0]) MAX_TIME_STEP = len(data['idxSentences'][0]) VOCABULARY_SIZE = len(data['vocabulary']) BATCH_NUM = int(math.ceil(EXAMPLE_NUM / float(100))) imgFeatures = tf.placeholder(tf.float32, [BATCH_SIZE, FEATURE_DIM]) idxSentences = tf.placeholder(tf.int32, [BATCH_SIZE, MAX_TIME_STEP]) input_lens = tf.placeholder(tf.int32, [BATCH_SIZE]) wordFeatures = tf.placeholder(tf.float32, [VOCABULARY_SIZE, FEATURE_DIM]) dcgan = DCGAN() losses = dcgan.loss(imgFeatures, idxSentences, input_lens, wordFeatures) train_op = dcgan.train(losses) saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) permut = np.array(range(EXAMPLE_NUM)) for epoch in range(MAX_EPOCH): np.random.shuffle(permut) for step in range(BATCH_NUM): _, g_loss_value, d_loss_value = sess.run( [train_op, losses[dcgan.g], losses[dcgan.d]],