Example #1
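# NOTE: excerpt from a larger WGAN training script. It assumes module-level
# imports (os, numpy as np, tensorflow as tf, and input_data from
# tensorflow.examples.tutorials.mnist) and globals/helpers defined elsewhere:
# is_svhn, device, batch_size, Citers, max_iter_step, log_dir, ckpt_dir,
# load_svhn, and build_graph.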
def main():
    if is_svhn:
        dataset = load_svhn()
    else:
        tf.logging.info("Read data from {}".format("MNIST_data"))
        dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    with tf.device(device):
        tf.logging.info("Build graph on {}".format(device))
        (opt_g, opt_c, real_data, c_loss, g_loss,
         fake_logit, true_logit) = build_graph()  # opt_g, opt_c, real_data = build_graph()
    merged_all = tf.summary.merge_all()
    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1.0

    def next_feed_dict():
        train_img = dataset.train.next_batch(batch_size)[0]
        train_img = 2 * train_img - 1  # rescale pixels from [0, 1] to [-1, 1]
        if not is_svhn:  # MNIST digits need padding from 28x28 to 32x32
            train_img = np.reshape(train_img, (-1, 28, 28))
            npad = ((0, 0), (2, 2), (2, 2))
            train_img = np.pad(train_img,
                               pad_width=npad,
                               mode='constant',
                               constant_values=-1)
            train_img = np.expand_dims(train_img, -1)  # [batch, 32, 32, 1]
        feed_dict = {real_data: train_img}
        return feed_dict

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        for i in range(max_iter_step):
            # WGAN heuristic: train the critic much harder (100 updates) for
            # the first 25 steps and on every 500th step; otherwise run the
            # usual Citers critic updates per generator update.
            if i < 25 or i % 500 == 0:
                citers = 100
            else:
                citers = Citers
            # Full-trace options, consumed by the periodic profiling runs below.
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            for j in range(citers):
                feed_dict = next_feed_dict()
                if i % 100 == 99 and j == 0:
                    tf.logging.info("Save critic model at step %i" % i)
                    _, merged, c_loss_np, fake_logit_np, true_logit_np = sess.run(
                        [opt_c, merged_all, c_loss, fake_logit, true_logit],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)
                    summary_writer.add_summary(merged, i)
                    summary_writer.add_run_metadata(  # save run metadata every 100 steps
                        run_metadata, 'critic_metadata {}'.format(i), i)
                    tf.logging.info(
                        "Step = {}, c_loss = {:.4f}, f_logit = {}, r_logit = {}"
                        .format(i, c_loss_np, fake_logit_np[:5],
                                true_logit_np[:5]))
                else:
                    _, c_loss_np = sess.run([opt_c, c_loss],
                                            feed_dict=feed_dict)
                    # if i % 100 == 99:
                    #     tf.logging.info(
                    #         "Optimize the critic model step {}/{}, c_loss = {:.4f}"
                    #         .format(j, i, c_loss_np))
            feed_dict = next_feed_dict()
            if i % 100 == 99:
                tf.logging.info("Saving generator model as step %i" % i)
                _, merged, g_loss_np, fake_logit_np = sess.run(
                    [opt_g, merged_all, g_loss, fake_logit],
                    feed_dict=feed_dict,
                    options=run_options,
                    run_metadata=run_metadata)
                summary_writer.add_summary(merged, i)
                summary_writer.add_run_metadata(
                    run_metadata, 'generator_metadata {}'.format(i), i)
                tf.logging.info(
                    "Step = {}, g_loss = {:.4f}, f_logit = {}".format(
                        i, g_loss_np, fake_logit_np[:5]))
            else:
                # tf.logging.info("Optimize the generator model!")
                sess.run(opt_g, feed_dict=feed_dict)
            if i % 1000 == 999:
                saver.save(sess,
                           os.path.join(ckpt_dir, "model.ckpt"),
                           global_step=i)
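
A note on the padding in next_feed_dict: MNIST batches arrive as flat 784-vectors in [0, 1], so the helper rescales them to [-1, 1] and pads the 28x28 digits out to the 32x32 input the critic expects, filling with -1 (the rescaled black level). A minimal standalone sketch of that transform using only numpy (the random batch is a stand-in for dataset.train.next_batch):

import numpy as np

batch = np.random.rand(64, 784)  # stand-in for dataset.train.next_batch(batch_size)[0]
imgs = 2 * batch - 1             # rescale [0, 1] -> [-1, 1]
imgs = np.reshape(imgs, (-1, 28, 28))
imgs = np.pad(imgs, pad_width=((0, 0), (2, 2), (2, 2)),
              mode='constant', constant_values=-1)  # 28x28 -> 32x32, filled with "black"
imgs = np.expand_dims(imgs, -1)  # add channel axis -> [batch, 32, 32, 1]
assert imgs.shape == (64, 32, 32, 1)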
Example #2
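# NOTE: excerpt; relies on the same module-level imports and globals as
# Example #1 (tf, np, os, input_data, is_svhn, device, batch_size, Citers,
# max_iter_step, log_dir, ckpt_dir, load_svhn, build_graph).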
def main():
    if is_svhn:
        dataset = load_svhn()
    else:
        dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    with tf.device(device):
        opt_g, opt_c, real_data = build_graph()
    merged_all = tf.summary.merge_all()
    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.8

    def next_feed_dict():
        train_img = dataset.train.next_batch(batch_size)[0]
        train_img = 2 * train_img - 1
        if not is_svhn:
            train_img = np.reshape(train_img, (-1, 28, 28))
            npad = ((0, 0), (2, 2), (2, 2))
            train_img = np.pad(train_img,
                               pad_width=npad,
                               mode='constant',
                               constant_values=-1)
            train_img = np.expand_dims(train_img, -1)
        feed_dict = {real_data: train_img}
        return feed_dict

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        for i in range(max_iter_step):
            if i < 25 or i % 500 == 0:
                citers = 100
            else:
                citers = Citers
            for j in range(citers):
                feed_dict = next_feed_dict()
                if i % 100 == 99 and j == 0:
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    _, merged = sess.run([opt_c, merged_all],
                                         feed_dict=feed_dict,
                                         options=run_options,
                                         run_metadata=run_metadata)
                    summary_writer.add_summary(merged, i)
                    summary_writer.add_run_metadata(
                        run_metadata, 'critic_metadata {}'.format(i), i)
                else:
                    sess.run(opt_c, feed_dict=feed_dict)
            feed_dict = next_feed_dict()
            if i % 100 == 99:
                # Reuses run_options/run_metadata created at j == 0 of the
                # critic loop above; the two conditions fire on the same
                # steps (given citers >= 1), so both names are defined here.
                _, merged = sess.run([opt_g, merged_all],
                                     feed_dict=feed_dict,
                                     options=run_options,
                                     run_metadata=run_metadata)
                summary_writer.add_summary(merged, i)
                summary_writer.add_run_metadata(
                    run_metadata, 'generator_metadata {}'.format(i), i)
            else:
                sess.run(opt_g, feed_dict=feed_dict)
            if i % 1000 == 999:
                saver.save(sess,
                           os.path.join(ckpt_dir, "model.ckpt"),
                           global_step=i)
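
Both examples use the same TF1 profiling idiom: a tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) plus tf.RunMetadata() pair is threaded through sess.run and then attached to TensorBoard with add_run_metadata. A self-contained sketch of just that idiom, assuming TensorFlow 1.x (the toy graph and the /tmp/trace_demo log directory are placeholders):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
loss = tf.reduce_mean(tf.square(x))
tf.summary.scalar('loss', loss)
merged = tf.summary.merge_all()

with tf.Session() as sess:
    writer = tf.summary.FileWriter('/tmp/trace_demo', sess.graph)
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    summary = sess.run(merged,
                       feed_dict={x: [[1.0] * 8]},
                       options=run_options,
                       run_metadata=run_metadata)
    writer.add_summary(summary, 0)
    # the tag must be unique per call, hence the '... {}'.format(i) pattern above
    writer.add_run_metadata(run_metadata, 'step_metadata 0', 0)
    writer.close()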
Example #3
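# NOTE: excerpt of a script's configuration block; assumes dataset_name,
# image_width, image_height, and helpers such as download_celeb_a and
# load_svhn are defined elsewhere, plus imports (os, glob.glob, numpy as np,
# PIL.Image, input_data).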
beta2 = 0.9
max_iter_step = 20000
channels = 1 if dataset_name == 'mnist' else 3
log_path = './' + dataset_name + '/log_wgan'
ckpt_path = './' + dataset_name + '/ckpt_wgan'
ckpt_step_path = ckpt_path + '.step'
##################################################################
if dataset_name == 'mnist':
    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    ndf = ngf = 16
elif dataset_name == 'celeba':
    download_celeb_a()
    dataset = glob(os.path.join('./data/', 'celebA/*.jpg'))
    count = len(dataset)
elif dataset_name == 'svhn':
    dataset = load_svhn()
##################################################################


def get_image(image_path):
    # Load an image and, unless it already matches the target size,
    # center-crop the standard 108x108 CelebA face region before resizing.
    image = Image.open(image_path)
    if image.size != (image_width, image_height):
        face_width = face_height = 108
        j = (image.size[0] - face_width) // 2   # left offset of center crop
        i = (image.size[1] - face_height) // 2  # top offset of center crop
        image = image.crop([j, i, j + face_width, i + face_height])
        image = image.resize([image_width, image_height], Image.BILINEAR)
    return np.array(image.convert('RGB'))
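
get_image depends on the module-level image_width and image_height set in the configuration block; for CelebA it takes the conventional 108x108 center crop before resizing. A hypothetical usage sketch (the 64x64 target size and the file name are assumptions, not values from this excerpt):

from PIL import Image  # Pillow; also requires numpy as np and get_image above

image_width = image_height = 64  # assumed target size; set elsewhere in the real script
arr = get_image('./data/celebA/000001.jpg')  # hypothetical file matched by the glob above
print(arr.shape)  # (64, 64, 3): 108x108 center crop, resized to 64x64, RGB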


def get_batches():