def main(_):
    """Build a DCGAN from the command-line flags, then train or evaluate it.

    The single positional argument is accepted and ignored (presumably the
    argv list handed over by the flags runner -- confirm at the call site).

    Steps:
      * pretty-print the parsed flags;
      * when ``input_width`` / ``output_width`` are ``None``, default them to
        ``input_height`` / ``output_height`` respectively;
      * create ``checkpoint_dir`` and ``sample_dir`` if they do not exist;
      * open a ``tf.Session`` with ``gpu_options.allow_growth`` enabled and
        construct the ``DCGAN`` (MNIST additionally gets ``y_dim=10``);
      * if ``FLAGS.train`` is set, call ``dcgan.train(FLAGS)``; otherwise load
        a checkpoint and run ``vector_algebra`` and ``interpolate`` with
        ``FLAGS.seed_one`` / ``FLAGS.seed_two``.

    Raises:
        Exception: when not training and ``dcgan.load(...)`` reports failure
            (its first element is falsy).
    """
    pp.pprint(flags.FLAGS.__flags)

    # Unset widths fall back to the corresponding heights.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    run_config = tf.ConfigProto()
    # Let TensorFlow grow its GPU memory usage on demand.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        # Constructor options shared by every dataset, kept in one place so a
        # new option cannot be added to one branch and forgotten in another.
        dcgan_kwargs = dict(
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        if FLAGS.dataset == 'mnist':
            # Only MNIST passes y_dim; other datasets keep DCGAN's default.
            # (Presumably 10 = number of digit classes -- confirm in DCGAN.)
            dcgan_kwargs['y_dim'] = 10
        dcgan = DCGAN(sess, **dcgan_kwargs)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            # load() returns an indexable result; a falsy first element is
            # treated as "no usable checkpoint".
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")
            # NOTE(review): confirm these are meant to run only in test mode
            # (not also after training).
            dcgan.vector_algebra(FLAGS,
                                 seed_one=FLAGS.seed_one,
                                 seed_two=FLAGS.seed_two)
            dcgan.interpolate(FLAGS,
                              seed_one=FLAGS.seed_one,
                              seed_two=FLAGS.seed_two)
#!/usr/bin/env python3.4
#
# Irmak Sirer
# License: MIT
# 2016-09
#
# Restore a DCGAN from --checkpointDir and run DCGAN.interpolate, passing the
# parsed command-line namespace as its ``config``.
import argparse
import os

import tensorflow as tf

from model import DCGAN

parser = argparse.ArgumentParser()
parser.add_argument('--imgSize', type=int, default=64)
parser.add_argument('--lam', type=float, default=0.1)
parser.add_argument('--checkpointDir', type=str, default='checkpoint')
parser.add_argument('--outDir', type=str, default='interpolation')
parser.add_argument('--vector1', type=str, default='')
parser.add_argument('--vector2', type=str, default='')
args = parser.parse_args()

# Validate explicitly rather than with ``assert``: asserts are stripped when
# Python runs with -O, which would let a bad path through silently.
if not os.path.isdir(args.checkpointDir):
    parser.error("checkpoint directory not found: %s" % args.checkpointDir)

# Named session_config so it is not confused with the ``config=`` argument
# handed to interpolate() below (which is the CLI namespace, not this object).
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True
with tf.Session(config=session_config) as sess:
    dcgan = DCGAN(sess, image_size=args.imgSize,
                  checkpoint_dir=args.checkpointDir, lam=args.lam)
    # Presumably interpolate() reads outDir / vector1 / vector2 from the
    # namespace -- confirm against model.DCGAN.interpolate.
    dcgan.interpolate(config=args)