Example #1
0
File: main.py Project: SmartAI/GAN
def main(_):
    """Entry point: build a DCGAN, restore or train it, then generate samples.

    Side effects: creates checkpoint/sample/log directories when missing.

    Raises:
        IOError: if checkpoints exist but cannot be read, or if no
            checkpoints exist and training was not requested.
    """
    pp.pprint(FLAGS.__flags)
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # Ensure every output directory exists before anything writes to it.
        model_dir = dcgan.get_model_dir()
        for path in (FLAGS.checkpoint_dir,
                     os.path.join(FLAGS.sample_dir, model_dir),
                     os.path.join(FLAGS.log_dir, model_dir)):
            if not os.path.exists(path):
                os.makedirs(path)

        if dcgan.checkpoint_exists():
            print("Loading checkpoints")
            if dcgan.load():
                print("Success")
            else:
                raise IOError("Could not read checkpoints from {}".format(
                    FLAGS.checkpoint_dir))
        else:
            # Without checkpoints, sampling below would be meaningless.
            if not FLAGS.train:
                raise IOError("No checkpoints found")
            print("No checkpoints found. Training from scratch")
            dcgan.load()

        if FLAGS.train:
            train(dcgan)

        print("Generating samples...")
        inference.sample_images(dcgan)
        inference.visualize_z(dcgan)
Example #2
0
def main(_):
    """Entry point: construct a DCGAN, restore weights if available, train or
    reload for inference.

    Side effects: creates checkpoint/log directories when missing.

    Raises:
        IOError: if checkpoints exist but cannot be read.
    """
    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # path checks: make sure output directories exist before writing
        if not os.path.exists(FLAGS.checkpoint_dir):
            os.makedirs(FLAGS.checkpoint_dir)
        log_path = os.path.join(FLAGS.log_dir, dcgan.get_model_dir())
        if not os.path.exists(log_path):
            os.makedirs(log_path)

        # load checkpoint if found
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if dcgan.load():
                print("success!")
            else:
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
        else:
            print("No checkpoints found. Training from scratch.")
            dcgan.load()

        # train DCGAN
        if FLAGS.train:
            train(dcgan)
        else:
            # NOTE(review): dcgan.load() is already called in both branches
            # above, so this third call looks redundant -- presumably load()
            # is idempotent; confirm before removing.
            dcgan.load()
Example #3
0
def load_model(model_params, contF=True):
    """Build a DCGAN model, optionally restoring weights from a pickled snapshot.

    Args:
        model_params: configuration object forwarded to the DCGAN constructor.
        contF: when True, reload generator/discriminator weights from the
            pickle at ``$LOAD_PATH`` ("continue from last time").

    Returns:
        The constructed (and possibly weight-restored) DCGAN model.
    """
    from dcgan import DCGAN
    import os

    model = DCGAN(model_params, ltype=os.environ['LTYPE'])

    if contF:
        # Restore from the previously pickled model: pull its numpy weight
        # arrays out and copy them into the freshly built networks.
        from utils import unpickle
        snapshot = unpickle(os.environ['LOAD_PATH'])

        gen_weights = [p.get_value() for p in snapshot.gen_network.params]
        dis_weights = [p.get_value() for p in snapshot.dis_network.params]

        model.load(dis_weights, gen_weights, verbose=False)

    return model
Example #4
0
def main(_):
    """Entry point: prepare directories, restore or train a DCGAN, then sample.

    Side effects: creates checkpoint/log/sample directories when missing.

    Raises:
        IOError: if checkpoints exist but cannot be read, or if sampling is
            requested with no checkpoints and no training.
    """
    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # path checks: make sure every output directory exists
        model_dir = dcgan.get_model_dir()
        for path in (FLAGS.checkpoint_dir,
                     os.path.join(FLAGS.log_dir, model_dir),
                     os.path.join(FLAGS.sample_dir, model_dir)):
            if not os.path.exists(path):
                os.makedirs(path)

        # load checkpoint if found
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if dcgan.load():
                print("success!")
            else:
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
        else:
            # Sampling requires a trained model; fail fast if none exists.
            if not FLAGS.train:
                raise IOError("No checkpoints found but need for sampling!")
            print("No checkpoints found. Training from scratch.")
            dcgan.load()

        # train DCGAN
        if FLAGS.train:
            train(dcgan)

        # inference/visualization code goes here
        print("Generating samples...")
        inference.sample_images(dcgan)
        print("Generating visualizations of z...")
        inference.visualize_z(dcgan)
Example #5
0
File: main.py Project: cmcuza/DCGAN
def main(_):
    """Entry point: normalize flags, build a DCGAN, then train, predict or test.

    Side effects: creates checkpoint/sample directories when missing.

    Raises:
        Exception: when test/predict mode is requested without a trained
            checkpoint to load.
    """
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    # Default to square images: missing widths fall back to the heights.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    # Ensure output directories exist before training/sampling writes to them.
    for path in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(path):
            os.makedirs(path)

    # NOTE(review): GPU options (allow_growth, memory fraction) were previously
    # sketched here; pass a tf.compat.v1.ConfigProto to Session if needed.
    with tf.compat.v1.Session() as sess:
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.sample_num,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            data_dir=FLAGS.data_dir)

        # Print a parameter-count summary of the model.
        model_vars = tf.trainable_variables()
        slim.model_analyzer.analyze_vars(model_vars, print_info=True)

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            # Test/predict modes require a previously trained checkpoint.
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")
            if FLAGS.predict:
                dcgan.predict(FLAGS.predict_dataset)
            else:
                dcgan.test()
Example #6
0
        samples = sess.run(model.G, feed_dict={Z: noise, is_training: False})
        print(samples.min(), samples.max())

        save_binvox(out_dir + "{}.binvox".format(epoch),
                    samples[0, :, :, :, 0] > 0.9)
#        test_noise = get_moving_noise(sample_size, n_noise)
#        test_samples = sess.run(model.G, feed_dict={Z: test_noise, is_training: False})
#        path = "out8/{}/".format(epoch)
#        if not os.path.exists(path): os.makedirs(path)
#
#        for i, data in enumerate(samples):
#            save_binvox(path + "{}.binvox".format(i), data[:, :, :, 0] > 0.9)

#%%
# Test cell (notebook-style): restore the trained generator and dump
# binarized voxel samples to disk.
""" test """
model.load(sess, log_dir)  # restore weights saved under log_dir

epoch = 'test3'  # NOTE(review): used only to name the output folder
sample_size = 10
noise = get_noise(sample_size, n_noise)
# Run the generator in inference mode (is_training=False) on random latents.
samples = sess.run(model.G, feed_dict={Z: noise, is_training: False})
#test_noise = get_moving_noise(sample_size, n_noise)
#test_samples = sess.run(model.G, feed_dict={Z: test_noise, is_training: False})
path = "out/{}/".format(epoch)
if not os.path.exists(path): os.makedirs(path)

# Threshold channel 0 at 0.9 to binarize each voxel grid before saving.
for i, data in enumerate(samples):
    save_binvox(path + "{}.binvox".format(i), data[:, :, :, 0] > 0.9)

#import matplotlib.pyplot as plt
#for i, data in enumerate(samples):
                    default='center')
# Fraction of each image side covered by the centered hole (see mask below).
parser.add_argument('--center-scale', dest='center_scale', type=float, default=0.25)
# One or more input image paths (positional).
parser.add_argument('imgs', type=str, nargs='+')
parser.add_argument('--log-l1-loss', dest='log_l1_loss', action='store_true')

args = parser.parse_args()


# Session config: grow GPU memory on demand instead of grabbing it all.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    model = DCGAN(sess, image_size=args.img_size,
                  checkpoint_dir=args.checkpoint_dir, lam=args.lam)
    # Need to have a loaded model
    tf.global_variables_initializer().run()
    # NOTE(review): assert is stripped under `python -O`; a trained
    # checkpoint is required here -- consider raising explicitly instead.
    assert(model.load(model.checkpoint_dir))

    # Construct mask -- presumably 1.0 = keep pixel, 0.0 = hole; confirm
    # against how the mask is consumed downstream.
    image_shape = imread(args.imgs[0]).shape
    mask_type = args.mask_type
    if mask_type == 'random':
        # Knock out ~20% of pixel positions (across all channels).
        fraction_masked = 0.2
        mask = np.ones(image_shape)
        mask[np.random.random(image_shape[:2]) < fraction_masked] = 0.0
    elif mask_type == 'center':
        # Zero a centered square spanning [center_scale, 1-center_scale]
        # of the first image dimension on both axes.
        center_scale = args.center_scale
        assert(center_scale <= 0.5)
        mask = np.ones(image_shape)
        l = int(image_shape[0] * center_scale)
        u = int(image_shape[0] * (1.0 - center_scale))
        mask[l:u, l:u, :] = 0.0