示例#1
0
文件: main.py 项目: oleob/DCGAN
def main():
    """Prepare the train/test input tensors and run DCGAN training.

    Calls ``dataset.train_test_data`` on ``./PetImages_resize``, turns each
    split into a ``get_next()`` tensor of a one-shot iterator, and then
    builds, initializes and trains a ``DCGAN`` inside one ``tf.Session``.
    ``batch_size`` and ``num_epoch`` are not parameters; they are resolved
    from the enclosing scope.
    """
    splits = dataset.train_test_data('./PetImages_resize', batch_size)
    train_images, test_images, train_labels, test_labels = splits
    train_size, test_size = len(train_labels), len(test_labels)

    # Training input: cache, shuffle with a buffer as large as the split,
    # batch, then repeat for ``num_epoch`` passes.
    train_pipeline = dataset.create_dataset(train_images, train_labels)
    train_pipeline = train_pipeline.cache()
    train_pipeline = train_pipeline.shuffle(buffer_size=train_size)
    train_pipeline = train_pipeline.batch(batch_size)
    train_pipeline = train_pipeline.repeat(num_epoch)
    train_dataset = train_pipeline.make_one_shot_iterator().get_next()

    # Test input: shuffle buffer of 10 and a single batch of ``test_size``.
    test_pipeline = dataset.create_dataset(test_images, test_labels)
    test_pipeline = test_pipeline.cache()
    test_pipeline = test_pipeline.shuffle(buffer_size=10)
    test_pipeline = test_pipeline.batch(test_size)
    test_dataset = test_pipeline.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        settings = {
            'train_dataset': train_dataset,
            'test_dataset': test_dataset,
            'train_size': train_size,
            'test_size': test_size,
            'batch_size': batch_size,
            'num_epoch': num_epoch,
        }
        model = DCGAN(sess, **settings)
        model.build_model()
        # Spelling kept exactly as called in the original; presumably it
        # matches the method name defined on DCGAN -- TODO confirm.
        model.intialize_variables()
        # model.create_image_from_generator()
        model.train()
示例#2
0
def main(_):
    """Build a DCGAN from the command-line ``FLAGS`` and train it.

    The positional argument is ignored. The model is constructed with
    ``FLAGS.input_height`` and ``FLAGS.input_width`` and trained with the
    whole ``FLAGS`` object, all inside a default-configured ``tf.Session``.
    """
    # A custom session config (per-process GPU memory fraction of 0.333 with
    # allow_growth enabled) was left disabled here; defaults are used.
    with tf.Session() as session:
        height = FLAGS.input_height
        width = FLAGS.input_width
        model = DCGAN(session, height, width)
        model.build_model()
        model.train(FLAGS)
示例#3
0
def train(args):
    """Build the GAN selected by ``args.model`` and train it.

    Steps performed:

    1. Missing ``input_width`` / ``output_width`` default to the matching
       height.
    2. ``save_dir`` and ``temp_samples_dir`` get a ``_<model>`` suffix and
       are created if absent.
    3. If ``args.init_from`` is set, ``config.pkl`` is loaded from that
       directory, the shape-defining options are required to match the
       command line, and the saved configuration replaces ``args`` for the
       rest of the run (keeping the given ``init_from``). Otherwise ``args``
       is pickled to ``<save_dir>/config.pkl``.
    4. The model is built, optionally restored via ``gan.load``, and trained
       with ``gan.train(args)``.

    Args:
        args: Namespace-like object (``vars(args)`` must work) holding the
            run options. It is mutated in place (widths, ``save_dir``,
            ``temp_samples_dir``).

    Raises:
        AssertionError: If ``config.pkl`` is missing from ``init_from`` or
            the saved configuration disagrees with the command line.
        ValueError: If the effective model name is not ``'DCGAN'``,
            ``'WGAN'`` or ``'WGAN_v2'``.
    """
    if args.input_width is None:
        args.input_width = args.input_height
    if args.output_width is None:
        args.output_width = args.output_height

    # Suffix the output locations with the model name.
    args.save_dir = args.save_dir + '_' + args.model
    args.temp_samples_dir = args.temp_samples_dir + '_' + args.model
    for directory in (args.save_dir, args.temp_samples_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    saved_args = None
    if args.init_from is not None:
        config_path = os.path.join(args.init_from, 'config.pkl')
        # Raised explicitly (instead of ``assert``) so the check still runs
        # under ``python -O`` while keeping the same exception type.
        if not os.path.isfile(config_path):
            raise AssertionError(
                "config.pkl file does not exist in path %s" % args.init_from)
        # SECURITY: unpickling runs arbitrary code from the file; only
        # resume from directories you trust.
        with open(config_path, 'rb') as f:
            saved_args = cPickle.load(f)
        # The saved model can only be reused if these options are unchanged.
        need_be_same = ['y_dim', 't_dim', 'z_dim', 'input_height',
                        'input_width', 'output_height', 'output_width']
        saved_vars = vars(saved_args)
        current_vars = vars(args)
        for checkme in need_be_same:
            if saved_vars[checkme] != current_vars[checkme]:
                raise AssertionError(
                    "Command line argument and saved model disagree on '%s' "
                    % checkme)
    else:
        with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
            cPickle.dump(args, f)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = False

    with tf.Session(config=run_config) as sess:
        if saved_args is not None:
            # Resume with the stored configuration, but remember where the
            # checkpoint is being loaded from.
            init_from = args.init_from
            args = saved_args
            args.init_from = init_from

        # Constructor arguments shared by every supported model class.
        common_kwargs = dict(
            input_width=args.input_width,
            input_height=args.input_height,
            output_width=args.output_width,
            output_height=args.output_height,
            batch_size=args.batch_size,
            sample_num=args.batch_size,
            y_dim=args.y_dim,
            t_dim=args.t_dim,
            z_dim=args.z_dim,
            dataset_name=args.dataset,
            input_fname_pattern=args.input_fname_pattern,
            crop=args.crop,
            save_dir=args.save_dir,
            temp_samples_dir=args.temp_samples_dir,
            tag_filename=args.tag_filename,
            tag_filename_sp=args.tag_filename_sp)

        # Model-specific options are only read for the model that uses them.
        if args.model == 'DCGAN':
            gan = DCGAN(sess, args.model, **common_kwargs)
        elif args.model == 'WGAN':
            gan = WGAN(sess, args.model,
                       clipping_value=args.clipping_value,
                       **common_kwargs)
        elif args.model == 'WGAN_v2':
            gan = WGAN_v2(sess, args.model,
                          scale=args.scale,
                          **common_kwargs)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``gan``.
            raise ValueError(
                "Unknown model %r; expected 'DCGAN', 'WGAN' or 'WGAN_v2'"
                % (args.model,))

        gan.build_model()

        if args.init_from is not None:
            gan.load(args.init_from)

        show_all_variables()
        gan.train(args)
示例#4
0
def generate(args):
    """Sample images from a saved GAN for every caption in a text file.

    The model configuration is read from ``<init_from>/config.pkl``. Each
    non-blank line of ``args.testing_text`` is split at its first comma into
    an id and a description; the description is encoded with skip-thoughts
    and used as the conditioning vector. For each id, up to five images from
    one sampled batch are written to ``args.sample_dir`` as
    ``sample_<id>_<k>.jpg``.

    Args:
        args: Namespace-like object providing ``init_from``,
            ``testing_text`` and ``sample_dir``.

    Returns:
        None. Returns early, after printing a message, when the checkpoint
        cannot be loaded.

    Raises:
        AssertionError: If ``config.pkl`` is missing from ``init_from``.
        ValueError: If the saved model name is not ``'DCGAN'``, ``'WGAN'``
            or ``'WGAN_v2'``, or a non-blank caption line has no comma.
        IndexError: If ``y_dim`` is 9600 and a caption lacks a
            ``"<word> hair"`` or ``"<word> eyes"`` phrase.
    """
    config_path = os.path.join(args.init_from, 'config.pkl')
    # Raised explicitly (instead of ``assert``) so the check still runs
    # under ``python -O`` while keeping the same exception type.
    if not os.path.isfile(config_path):
        raise AssertionError(
            "config.pkl file does not exist in path %s" % args.init_from)
    # SECURITY: unpickling runs arbitrary code from the file; only load
    # configurations from directories you trust.
    with open(config_path, 'rb') as f:
        saved_args = cPickle.load(f)

    # Parse the testing captions and encode them.
    skipthoughts_model = skipthoughts.load_model()
    test_dict = {}
    with open(args.testing_text, 'r') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            # Split only on the first comma so that descriptions may
            # themselves contain commas.
            idx, desc = line.split(',', 1)
            if saved_args.y_dim == 9600:
                # Encode the hair word and the eyes word separately, then
                # concatenate the two encodings along axis 1.
                attr_vecs = []
                for attr in ('hair', 'eyes'):
                    found = re.findall('[a-zA-Z]+ ' + attr, desc, flags=0)
                    word = [re.sub(' ' + attr, '', found[0])]
                    attr_vecs.append(
                        skipthoughts.encode(skipthoughts_model,
                                            word,
                                            verbose=False))
                test_dict[idx] = np.concatenate(attr_vecs, 1)
            else:
                test_dict[idx] = skipthoughts.encode(skipthoughts_model,
                                                     [desc.strip()],
                                                     verbose=False)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = False

    with tf.Session(config=run_config) as sess:
        # Constructor arguments shared by every supported model class.
        common_kwargs = dict(
            input_width=saved_args.input_width,
            input_height=saved_args.input_height,
            output_width=saved_args.output_width,
            output_height=saved_args.output_height,
            batch_size=saved_args.batch_size,
            sample_num=saved_args.batch_size,
            y_dim=saved_args.y_dim,
            t_dim=saved_args.t_dim,
            z_dim=saved_args.z_dim,
            dataset_name=saved_args.dataset,
            input_fname_pattern=saved_args.input_fname_pattern,
            crop=saved_args.crop,
            save_dir=saved_args.save_dir,
            temp_samples_dir=saved_args.temp_samples_dir,
            tag_filename=saved_args.tag_filename,
            tag_filename_sp=saved_args.tag_filename_sp,
            infer=True)

        # Model-specific options are only read for the model that uses them.
        if saved_args.model == 'DCGAN':
            gan = DCGAN(sess, saved_args.model, **common_kwargs)
        elif saved_args.model == 'WGAN':
            gan = WGAN(sess, saved_args.model,
                       clipping_value=saved_args.clipping_value,
                       **common_kwargs)
        elif saved_args.model == 'WGAN_v2':
            gan = WGAN_v2(sess, saved_args.model,
                          scale=saved_args.scale,
                          **common_kwargs)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``gan``.
            raise ValueError(
                "Unknown model %r; expected 'DCGAN', 'WGAN' or 'WGAN_v2'"
                % (saved_args.model,))

        gan.build_model()

        # Fall back to the older initializer name only when the newer one is
        # missing, instead of swallowing every exception (including
        # KeyboardInterrupt) with a bare ``except``.
        try:
            init_op = tf.global_variables_initializer()
        except AttributeError:
            init_op = tf.initialize_all_variables()
        init_op.run()

        # show_all_variables()
        could_load, _checkpoint_counter = gan.load(args.init_from)
        if not could_load:
            # Without restored weights the samples would be meaningless, and
            # the output directory may not even exist, so stop here.
            print('load fail!!!')
            return
        if not os.path.exists(args.sample_dir):
            os.makedirs(args.sample_dir)

        for idx, vec in test_dict.items():
            sample_z = np.random.uniform(-1,
                                         1,
                                         size=(gan.batch_size,
                                               saved_args.z_dim))
            # Use the same conditioning vector for every item in the batch.
            sample_y = np.repeat(vec, gan.batch_size, axis=0)
            samples = sess.run(gan.sampler,
                               feed_dict={
                                   gan.z: sample_z,
                                   gan.y: sample_y
                               })
            print(samples.shape)
            # Keep the first five images of the batch (fewer if the batch is
            # smaller, which previously raised IndexError).
            # NOTE(review): scipy.misc.imsave is absent from recent SciPy
            # releases -- confirm the pinned SciPy version provides it.
            for sid in range(min(5, len(samples))):
                scipy.misc.imsave(
                    os.path.join(
                        args.sample_dir,
                        'sample_{}_{}.jpg'.format(str(idx), str(sid + 1))),
                    samples[sid])
示例#5
0
from model import DCGAN
# from cocob_optimizer import COCOB

# Run settings for the training script. Only the two learning rates are used
# in the visible part of the script; the purpose of the other names is
# inferred from their spelling and should be confirmed against their uses.
train_dir = 'train/dcgan/face_gen/'  # presumably the output directory for this run -- TODO confirm
k = 1  # NOTE(review): meaning not visible here; verify against the training loop
num_epoch = 30  # presumably the number of training epochs -- TODO confirm
save_epoch = 5  # presumably the checkpoint interval in epochs -- TODO confirm
print_step = 500  # presumably the logging interval in steps -- TODO confirm
summary_step = 2500  # presumably the summary-writing interval in steps -- TODO confirm

# Learning rates handed to the discriminator (D) and generator (G)
# tf.train.AdamOptimizer instances created in the __main__ section.
learning_rate_D = 0.0002
learning_rate_G = 0.001

if __name__ == '__main__':
    model = DCGAN(batch_size=64)
    model.build_model()

    # set up optimization
    opt_D = tf.train.AdamOptimizer(learning_rate=learning_rate_D, beta1=0.5)
    opt_G = tf.train.AdamOptimizer(learning_rate=learning_rate_G, beta1=0.5)

    collection_D = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator'
    )  # collections of variables in variable_scope('discriminator')
    collection_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Generator')

    with tf.control_dependencies(collection_D):
        train_D = opt_D.minimize(loss=model.loss_D, var_list=model.var_D)
    with tf.control_dependencies(collection_G):
        train_G = opt_G.minimize(loss=model.loss_G,