예제 #1
0
파일: main.py 프로젝트: chrisbyd/sharegan
def gengerate(train_dir, suffix=''):
    """Sample images from a restored generator checkpoint and save them as .npy.

    Builds the generator sampling graph, restores weights from ``train_dir``
    via ``utils.restore_ckpt`` and runs the sampler
    ``args.num_generated_batches`` times. The batches are concatenated along
    axis 0, passed through ``denormalize`` and written to
    ``SAMPLE/<model_name>/<dataset>/X_gan_<model_name><suffix>.npy``.

    Args:
        train_dir: Checkpoint location handed to ``utils.restore_ckpt``.
        suffix: String appended to the model name in the output filename.

    Raises:
        ValueError: If ``args.num_generated_batches`` is less than 1, since
            there would be no batches to concatenate.
    """
    num_batches = args.num_generated_batches
    if num_batches < 1:
        raise ValueError(
            "num_generated_batches must be >= 1, got %r" % (num_batches,))

    model = get_model()
    log.info("Generating %i batches using suffix %s", num_batches, suffix)
    noise = sample_generate_noise()
    fake_images = model.sampler(noise, False)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    # Freeze the graph so no new ops can be added while sampling.
    tf.get_default_graph().finalize()

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        # Load the checkpoint weights into the freshly built graph.
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        output = np.concatenate(
            [sess.run(fake_images) for _ in range(num_batches)], axis=0)
    output = denormalize(output)

    out_dir = os.path.join(SAMPLE, args.model_name, args.dataset)
    # np.save does not create missing parent directories.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    np.save(os.path.join(out_dir, "X_gan_%s.npy" % (args.model_name + suffix)),
            output)
예제 #2
0
def view_pics(args, from_disk, generated_imgs, index):
    """Plot the last 16 generated images in a grid and save them as an .eps file.

    Args:
        args: Parsed options; ``args.dataset`` selects the data config and
            ``args.model_name`` names the .npy file when loading from disk.
        from_disk: If true, load images from
            ``DATASETS/<dataset>/X_gan_<model_name>.npy``; otherwise use
            ``generated_imgs``.
        generated_imgs: Array-like of images, used when ``from_disk`` is false.
        index: Integer embedded in the output file name
            ``generated_pic<index>.eps``.
    """
    datacfg = data_cfg.get_config(args.dataset)
    size = datacfg.dataset.image_size
    channels = datacfg.dataset.channels
    if from_disk:
        # NOTE(review): gengerate() saves under SAMPLE/<model>/<dataset>/,
        # while this reads DATASETS/<dataset>/ — confirm the intended location.
        generated_pics = np.load(
            os.path.join(DATASETS, args.dataset,
                         'X_gan_%s.npy' % args.model_name))
    else:
        generated_pics = generated_imgs

    gs = gridspec.GridSpec(8, 8)
    gs.update(wspace=0.05, hspace=0.05)
    # Flatten every image to one row, then keep only the last 16.
    imgs_to_show = np.reshape(np.array(generated_pics),
                              [-1, size * size * channels])[-16:]
    fig = plt.figure(figsize=(4, 4))

    for i, img in enumerate(imgs_to_show):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        # Use the configured channel count (previously hard-coded to 3,
        # which broke the reshape for non-RGB datasets).
        img = np.reshape(img, [size, size, channels])
        if channels == 1:
            # imshow expects a 2-D array for single-channel data.
            plt.imshow(img[:, :, 0], cmap='gray')
        else:
            plt.imshow(img)
    plt.savefig('generated_pic%i.eps' % index, dpi=250)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
예제 #3
0
파일: main.py 프로젝트: chrisbyd/sharegan
def get_inception_score():
    """Compute and log the Inception Score of the saved generated samples.

    Loads ``SAMPLE/<model_name>/<dataset>/X_gan_<model_name>.npy`` and
    passes the array to ``inception_score.get_inception_score``.

    Returns:
        The ``(mean, std)`` pair produced by the scorer; the second value is
        logged as the standard deviation.
    """
    sample_path = os.path.join(SAMPLE, args.model_name, args.dataset,
                               'X_gan_%s.npy' % args.model_name)
    generated_pics = np.load(sample_path)
    mean, std = inception_score.get_inception_score(generated_pics)
    log.info('the inception score is %s, with standard deviation %s',
             mean, std)
    return mean, std
예제 #4
0
파일: gan.py 프로젝트: chrisbyd/sharegan
def construct_model():
    """Build the GAN training graph.

    Loads the training input pipeline, builds the generator and the
    discriminator (applied to both real and generated batches), computes the
    losses and creates one train op per network.

    Returns:
        A tuple ``(Discriminator_train_op, Generator_train_op, iter_fn,
        D_loss, G_loss, fake_img)``.
    """
    model = get_model()
    # NOTE(review): the meaning of the literal 32 passed to load_dataset is
    # not visible here (image size?) — confirm.
    real_images, iter_fn = matcher.load_dataset('train', args.batch_size,
                                                args.dataset, 32)
    datacfg = data_cfg.get_config(args.dataset)
    model.construct_variables(args.noise_dim, datacfg)
    input_noise = sample_noise()
    fake_img = model.get_generator(input_noise, datacfg)

    logits_real = model.get_discriminator(real_images, datacfg)
    logits_fake = model.get_discriminator(fake_img, datacfg)

    # One optimizer per network: a single TF1 optimizer instance shares its
    # non-slot state (e.g. Adam's beta1/beta2 power accumulators) across all
    # minimize() calls, so both train ops would advance it on every step.
    # Assumes get_optimizer builds a new optimizer object on each call.
    D_optimizer = get_optimizer(name="", optimizer=args.optimizer)
    G_optimizer = get_optimizer(name="", optimizer=args.optimizer)
    D_loss, G_loss = model.get_loss(logits_real, logits_fake)
    # Restrict each update to variables whose names start with the
    # network's scope prefix.
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'discriminator')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
    Discriminator_train_op = D_optimizer.minimize(D_loss, var_list=D_vars)
    Generator_train_op = G_optimizer.minimize(G_loss, var_list=G_vars)
    return (Discriminator_train_op, Generator_train_op, iter_fn,
            D_loss, G_loss, fake_img)
예제 #5
0
파일: main.py 프로젝트: chrisbyd/sharegan
import logging
import logging.config
import os
from easydict import EasyDict as edict
from data_process import matcher
from model import origin_dcgan
import data_config as data_cfg
from evaluation import inception_score
from evaluation import frachet_inception_distance
import viewpics
from viewpics import image_manifold_size
from model import deep_dcgan
from model import dcgan_standard

# Module-level setup: per-model/per-dataset sample directory, logging
# configuration, the "gan" logger and the active dataset config.
sample_dir = os.path.join('./samples', args.model_name, args.dataset)
# logging.config is a submodule; plain `import logging` does not load it,
# so `import logging.config` is required for dictConfig to resolve.
logging.config.dictConfig(get_logging_config(args.model_name))
log = logging.getLogger("gan")
data_config = data_cfg.get_config(args.dataset)


def sample_noise():
    """Return a training noise tensor of shape (batch_size, noise_dim).

    Values are drawn uniformly between -1 and 1 by ``tf.random_uniform``.
    """
    noise_shape = [args.batch_size, args.noise_dim]
    lower, upper = -1, 1
    return tf.random_uniform(noise_shape, minval=lower, maxval=upper)


def sample_generate_noise(num_samples=100):
    """Return a sampling-time noise tensor of shape (num_samples, noise_dim).

    Values are drawn uniformly between -1 and 1 by ``tf.random_uniform``.

    Args:
        num_samples: Number of noise rows to draw per run; defaults to 100.
    """
    return tf.random_uniform([num_samples, args.noise_dim],
                             minval=-1, maxval=1)


def get_optimizer(name, optimizer=args.optimizer):
    if args.lr_decay:
        global_step = tf.train.get_global_step()