Code Example #1
def z_real_(bz):
    """Draw a batch of bz prior samples, each of dimension FLAGS.z_dim.

    FLAGS.z_dist selects the prior: 'u' = uniform, 'g' = Gaussian,
    'mg' = 10-component Gaussian mixture; anything else falls back to
    a 10-arm swiss roll.
    """
    if FLAGS.z_dist == 'u':
        return uniform(bz, FLAGS.z_dim, 0, 2)
    elif FLAGS.z_dist == 'g':
        return gaussian(bz, FLAGS.z_dim, 0, 2)
    elif FLAGS.z_dist == 'mg':
        return gaussian_mixture(bz, FLAGS.z_dim, 10)
    else:
        return swiss_roll(bz, FLAGS.z_dim, 10)
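The sampler helpers called above (uniform, gaussian, gaussian_mixture, swiss_roll) are project-local and not shown. Below is a minimal NumPy sketch of the first three, matching the call signatures above; the bodies are assumptions for illustration, not the project's code (gaussian_mixture assumes z_dim >= 2):

import numpy as np

def uniform(batch_size, z_dim, minv, maxv):
    # i.i.d. samples from U[minv, maxv]
    return np.random.uniform(minv, maxv, (batch_size, z_dim)).astype(np.float32)

def gaussian(batch_size, z_dim, mean, var):
    # isotropic Gaussian with the given mean and variance
    return np.random.normal(mean, np.sqrt(var), (batch_size, z_dim)).astype(np.float32)

def gaussian_mixture(batch_size, z_dim, num_components):
    # num_components Gaussian modes spread evenly on a circle in the
    # first two latent dimensions; each sample picks a mode uniformly
    angles = 2 * np.pi * np.random.randint(0, num_components, batch_size) / num_components
    z = np.random.normal(0, 0.5, (batch_size, z_dim)).astype(np.float32)
    z[:, 0] += 4 * np.cos(angles).astype(np.float32)
    z[:, 1] += 4 * np.sin(angles).astype(np.float32)
    return z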
Code Example #2
File: train.py  Project: TahjidEshan/GAN_models
# Imports used below; dataset, config, sampler, Operation, Process,
# build_graph and plt (a project plotting module) come from the repository.
import os
import pickle
import tensorflow as tf


def main(run_load_from_file=False):
    # load MNIST images
    images, labels = dataset.load_test_images()

    # config
    opt = Operation()
    opt.check_dir(config.ckpt_dir, is_restart=False)
    opt.check_dir(config.log_dir, is_restart=True)

    max_epoch = 510
    num_trains_per_epoch = 500
    batch_size_u = 100

    # training
    with tf.device(config.device):
        h = build_graph()

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    saver = tf.train.Saver(max_to_keep=2)

    with tf.Session(config=sess_config) as sess:
        # Load from checkpoint or start a new session
        if run_load_from_file:
            saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_dir))
            training_epoch_loss, _ = pickle.load(
                open(config.ckpt_dir + '/pickle.pkl', 'rb'))
        else:
            sess.run(tf.global_variables_initializer())
            training_epoch_loss = []

        # Recording loss per epoch
        process = Process()
        for epoch in range(max_epoch):
            process.start_epoch(epoch, max_epoch)
            # learning rate (held constant here; a schedule could be plugged in)
            learning_rate = 0.0001

            # Recording loss per iteration
            sum_loss_reconstruction = 0
            sum_loss_discriminator_z = 0
            sum_loss_discriminator_img = 0
            sum_loss_generator_z = 0
            sum_loss_generator_img = 0
            process_iteration = Process()
            for i in range(num_trains_per_epoch):
                process_iteration.start_epoch(i, num_trains_per_epoch)
                # Inputs (suffix _l -> labeled, _u -> unlabeled)
                images_u = dataset.sample_unlabeled_data(images, batch_size_u)
                if config.distribution_sampler == 'swiss_roll':
                    z_true_u = sampler.swiss_roll(batch_size_u, config.ndim_z,
                                                  config.num_types_of_label)
                elif config.distribution_sampler == 'gaussian_mixture':
                    z_true_u = sampler.gaussian_mixture(
                        batch_size_u, config.ndim_z, config.num_types_of_label)
                elif config.distribution_sampler == 'uniform_desk':
                    z_true_u = sampler.uniform_desk(batch_size_u,
                                                    config.ndim_z,
                                                    radius=2)
                elif config.distribution_sampler == 'gaussian':
                    z_true_u = sampler.gaussian(batch_size_u,
                                                config.ndim_z,
                                                var=1)
                elif config.distribution_sampler == 'uniform':
                    z_true_u = sampler.uniform(batch_size_u,
                                               config.ndim_z,
                                               minv=-1,
                                               maxv=1)
                else:
                    raise ValueError('unknown distribution_sampler: {}'.format(
                        config.distribution_sampler))

                # reconstruction_phase
                _, loss_reconstruction = sess.run([h.opt_r, h.loss_r],
                                                  feed_dict={
                                                      h.x: images_u,
                                                      h.lr: learning_rate
                                                  })

                # adversarial phase for discriminator_z and discriminator_img
                images_u_s = dataset.sample_unlabeled_data(
                    images, batch_size_u)
                _, loss_discriminator_z = sess.run([h.opt_dz, h.loss_dz],
                                                   feed_dict={
                                                       h.x: images_u,
                                                       h.z: z_true_u,
                                                       h.lr: learning_rate
                                                   })

                _, loss_discriminator_img = sess.run([h.opt_dimg, h.loss_dimg],
                                                     feed_dict={
                                                         h.x: images_u,
                                                         h.x_s: images_u_s,
                                                         h.lr: learning_rate
                                                     })

                # adversarial phase for generator
                _, loss_generator_z = sess.run([h.opt_e, h.loss_e],
                                               feed_dict={
                                                   h.x: images_u,
                                                   h.lr: learning_rate
                                               })

                _, loss_generator_img = sess.run([h.opt_d, h.loss_d],
                                                 feed_dict={
                                                     h.x: images_u,
                                                     h.lr: learning_rate
                                                 })

                sum_loss_reconstruction += loss_reconstruction
                sum_loss_discriminator_z += loss_discriminator_z
                sum_loss_discriminator_img += loss_discriminator_img
                sum_loss_generator_z += loss_generator_z
                sum_loss_generator_img += loss_generator_img

                if i % 100 == 0:  # show running averages periodically
                    process_iteration.show_table_2d(
                        i, num_trains_per_epoch, {
                            'reconstruction': sum_loss_reconstruction / (i + 1),
                            'discriminator_z': sum_loss_discriminator_z / (i + 1),
                            'discriminator_img': sum_loss_discriminator_img / (i + 1),
                            'generator_z': sum_loss_generator_z / (i + 1),
                            'generator_img': sum_loss_generator_img / (i + 1),
                        })

            average_loss_per_epoch = [
                sum_loss_reconstruction / num_trains_per_epoch,
                sum_loss_discriminator_z / num_trains_per_epoch,
                sum_loss_discriminator_img / num_trains_per_epoch,
                sum_loss_generator_z / num_trains_per_epoch,
                sum_loss_generator_img / num_trains_per_epoch,
                (sum_loss_discriminator_z + sum_loss_discriminator_img) /
                num_trains_per_epoch,
                (sum_loss_generator_z + sum_loss_generator_img) /
                num_trains_per_epoch
            ]
            training_epoch_loss.append(average_loss_per_epoch)
            training_loss_name = [
                'reconstruction', 'discriminator_z', 'discriminator_img',
                'generator_z', 'generator_img', 'discriminator', 'generator'
            ]

            if epoch % 1 == 0:  # every epoch
                process.show_bar(
                    epoch, max_epoch, {
                        'loss_r': average_loss_per_epoch[0],
                        'loss_d': average_loss_per_epoch[5],
                        'loss_g': average_loss_per_epoch[6]
                    })

                plt.scatter_labeled_z(
                    sess.run(h.z_r, feed_dict={h.x: images[:1000]}),
                    [int(var) for var in labels[:1000]],
                    dir=config.log_dir,
                    filename='z_representation-{}'.format(epoch))

            if epoch % 10 == 0:
                saver.save(sess,
                           os.path.join(config.ckpt_dir, 'model_checkpoint'),
                           global_step=epoch)
                pickle.dump((training_epoch_loss, training_loss_name),
                            open(config.ckpt_dir + '/pickle.pkl', 'wb'))
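dataset.sample_unlabeled_data is another project-local helper. Here is a minimal sketch of the behavior the loop above relies on, assuming images is an indexable NumPy array; the actual implementation may differ:

import numpy as np

def sample_unlabeled_data(images, batch_size):
    # draw a random minibatch of images, discarding any labels
    indices = np.random.choice(len(images), batch_size, replace=False)
    return images[indices]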
Code Example #3
# Imports used below; tools, spl (a sampler module), Cifar10, Generator,
# Discriminator and ave_loss come from the repository.
from datetime import datetime

import tensorflow as tf


def train():
    data_set = 'cifar10'
    prior = 'uniform'
    x_dim = 32
    x_chl = 3
    y_dim = 10
    z_dim = 64
    batch_size = 100
    num_epochs = 500 * 50  # training iterations: one batch per "epoch" here
    step_epochs = 100  # logging interval
    learn_rate = 0.0005

    root_path = 'save/{}/{}'.format(data_set, prior)
    save_path = tools.make_save_directory(root_path)

    z = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='z')
    y = tf.placeholder(dtype=tf.float32, shape=[batch_size, y_dim], name='y')
    x_real = tf.placeholder(dtype=tf.float32,
                            shape=[batch_size, x_dim, x_dim, x_chl],
                            name='x')

    generator = Generator(batch_size=batch_size, z_dim=z_dim, dataset=data_set)
    discriminator = Discriminator(batch_size=batch_size, dataset=data_set)

    x_fake = generator.generate_on_cifar10(z, y, train=True)
    d_out_real = discriminator.discriminator_cifar10(x_real)
    d_out_fake = discriminator.discriminator_cifar10(x_fake)

    # discriminator loss
    D_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_out_real), logits=d_out_real))
    D_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_out_fake), logits=d_out_fake))
    with tf.control_dependencies([D_loss_fake, D_loss_real]):
        D_loss = D_loss_fake + D_loss_real
    # generator loss
    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_out_fake), logits=d_out_fake))

    # optimizers
    all_variables = tf.trainable_variables()
    g_var = [var for var in all_variables if 'generator' in var.name]
    d_var = [var for var in all_variables if 'discriminator' in var.name]
    optimizer = tf.train.AdamOptimizer(learn_rate)
    G_solver = optimizer.minimize(G_loss, var_list=g_var)
    D_solver = optimizer.minimize(D_loss, var_list=d_var)

    # read data
    train_data = Cifar10(train=True)
    file = open('{}/train.txt'.format(save_path), 'w')
    sess = tf.Session()

    # train the model
    sess.run(tf.global_variables_initializer())
    ave_loss_list = [0, 0, 0]
    cur_time = datetime.now()

    # fixed z/y batch for periodically saving sample image grids
    save_step = int(num_epochs / 100)
    z_sample = spl.uniform(100, z_dim)
    y_sample = spl.onehot_categorical(100, y_dim)

    # training process
    for epochs in range(1, num_epochs + 1):
        batch_x, batch_y = train_data.next_batch(batch_size)
        s_z_real = spl.uniform(batch_size, z_dim)

        for _ in range(1):  # one discriminator update per batch
            sess.run(D_solver,
                     feed_dict={
                         z: s_z_real,
                         x_real: batch_x,
                         y: batch_y
                     })
        for _ in range(2):  # two generator updates per batch
            sess.run(G_solver, feed_dict={z: s_z_real, y: batch_y})

        loss_list = sess.run([D_loss_fake, D_loss_real, G_loss],
                             feed_dict={
                                 z: s_z_real,
                                 x_real: batch_x,
                                 y: batch_y
                             })
        ave_loss(ave_loss_list, loss_list, step_epochs)

        if epochs % save_step == 0:
            iter_counter = int(epochs / save_step)
            x_sample = sess.run(x_fake, feed_dict={z: z_sample, y: y_sample})
            tools.save_grid_images(x_sample,
                                   '{}/images/{}.png'.format(
                                       save_path, iter_counter),
                                   size=x_dim,
                                   chl=x_chl)

        # record information
        if epochs % step_epochs == 0:
            time_use = (datetime.now() - cur_time).seconds
            liner = ("Epoch {:d}/{:d}, loss_dis_fake {:9f}, "
                     "loss_dis_real {:9f}, loss_generator {:9f}, "
                     "time_use {:f}").format(epochs, num_epochs,
                                             ave_loss_list[0],
                                             ave_loss_list[1],
                                             ave_loss_list[2], time_use)
            print(liner)
            file.writelines(liner + '\n')
            ave_loss_list = [0, 0, 0]  # reset to 0
            cur_time = datetime.now()

    # save model
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    saver = tf.train.Saver(var_list=global_vars)
    saver.save(sess, save_path='{}/model'.format(save_path))

    # close all
    file.close()
    sess.close()
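Two more project-local helpers are assumed above: spl (the prior/label sampler module) and ave_loss (a running-average accumulator). Hedged sketches consistent with how they are called; the signatures and bodies are assumptions, not the project's code:

import numpy as np

def uniform(batch_size, z_dim):
    # spl.uniform: i.i.d. latent samples from U[-1, 1]
    return np.random.uniform(-1.0, 1.0, (batch_size, z_dim)).astype(np.float32)

def onehot_categorical(batch_size, y_dim):
    # spl.onehot_categorical: one-hot label vectors with uniformly random classes
    y = np.zeros((batch_size, y_dim), dtype=np.float32)
    y[np.arange(batch_size), np.random.randint(0, y_dim, batch_size)] = 1.0
    return y

def ave_loss(ave_loss_list, loss_list, step_epochs):
    # accumulate each loss scaled by 1/step_epochs so that, after
    # step_epochs iterations, ave_loss_list holds the mean losses;
    # the caller resets the list to zeros after each log line
    for k, loss in enumerate(loss_list):
        ave_loss_list[k] += loss / step_epochs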