示例#1
0
文件: trainer.py 项目: MasazI/dcgan
def train():
    """Train the DCGAN on the CSV-listed dataset and periodically save samples.

    Relies on module-level settings defined elsewhere in the file
    (DATA_DIR, SAMPLE_SIZE, CSVFILE, Z_DIM, EPOCHS, BATCH_SIZE,
    LOG_DEVICE_PLACEMENT, FLAGS.sample_dir).  TensorBoard summaries are
    written to ``./logs``; on every 10th step (1, 11, 21, ...) the result of
    ``dcgan.generate_images`` is written unchanged to
    ``FLAGS.sample_dir/out_<step>.png``.

    The summary writer, the queue-runner threads and the session are shut
    down even if training raises.
    """
    with tf.Graph().as_default():
        # data: input tensors are produced by queue runners started below
        dataset = traindataset.DataSet(DATA_DIR, SAMPLE_SIZE)
        images, _ = dataset.csv_inputs(CSVFILE)  # labels are not used

        z = tf.placeholder(tf.float32, [None, Z_DIM], name='z')

        dcgan = DCGAN("test", "./checkpoint")
        _, logits, logits_, G_sum, z_sum, d_sum, d__sum = dcgan.inference(
            images, z)
        (d_loss_fake, d_loss_real, d_loss_real_sum, d_loss_fake_sum,
         d_loss_sum, g_loss_sum, d_loss, g_loss) = dcgan.loss(logits, logits_)

        # Split trainable variables by name so each optimizer only updates
        # its own network.  NOTE(review): this is a substring match, so any
        # variable whose name merely contains 'd_' / 'g_' is included too.
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # train operations
        d_optim = D_train_op(d_loss, d_vars)
        g_optim = G_train_op(g_loss, g_vars)

        # sample op (presumably a 4x4 image grid encoded as PNG bytes, since
        # its result is written straight to a .png file below)
        generate_images = dcgan.generate_images(z, 4, 4)

        # per-network merged summaries
        g_sum = tf.merge_summary(
            [z_sum, d__sum, G_sum, d_loss_fake_sum, g_loss_sum])
        d_sum = tf.merge_summary([z_sum, d_sum, d_loss_real_sum, d_loss_sum])

        init_op = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=LOG_DEVICE_PLACEMENT))
        writer = tf.train.SummaryWriter("./logs", sess.graph_def)
        try:
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                counter = 1
                start_time = time.time()

                for epoch in range(EPOCHS):
                    for idx in range(dataset.batch_idxs):
                        batch_z = np.random.uniform(
                            -1, 1, [BATCH_SIZE, Z_DIM]).astype(np.float32)
                        feed = {z: batch_z}

                        # discriminator update
                        _, summary_str = sess.run([d_optim, d_sum], feed)
                        writer.add_summary(summary_str, counter)

                        # generator update, run twice per discriminator step
                        for _g_step in range(2):
                            _, summary_str = sess.run([g_optim, g_sum], feed)
                            writer.add_summary(summary_str, counter)

                        # Fetch all monitored losses in a single run so they
                        # describe the same input batch; separate runs would
                        # presumably each pull a fresh batch from the queue.
                        errD_fake, errD_real, errG = sess.run(
                            [d_loss_fake, d_loss_real, g_loss], feed)

                        print(
                            "epochs: %02d %04d/%04d time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                            % (epoch, idx, dataset.batch_idxs,
                               time.time() - start_time,
                               errD_fake + errD_real, errG))

                        if counter % 10 == 1:
                            print("generate samples.")
                            generated_image_eval = sess.run(
                                generate_images, feed)
                            filename = os.path.join(
                                FLAGS.sample_dir, 'out_%05d.png' % counter)
                            with open(filename, 'wb') as f:
                                f.write(generated_image_eval)
                        counter += 1
            finally:
                # stop the input-pipeline threads before closing the session
                coord.request_stop()
                coord.join(threads)
        finally:
            writer.close()
            sess.close()
示例#2
0
def main():
    """Train an unconditional DCGAN on LSUN bedroom images and save samples.

    Each epoch runs at most ``max_iters_per_epoch`` training steps (see the
    run-control settings below) and prints the generator/discriminator
    losses every ``print_every`` steps.  After epochs 1, 2, 5 and 10 the
    result of ``dcgan.sample_images`` is written unchanged to
    ``../work/samples/bedroom_<epoch>.jpg``.
    """
    import os

    # load data
    bedroom = LSUNdataset(dirn='../data', category='bedroom')

    # hyper-parameters
    nc = 3            # number of channels in image
    batch_size = 128  # number of examples in batch
    npx = 32          # width/height of images in pixels
    nz = 100          # dimension of z
    lr = 0.0002       # initial learning rate for adam

    # run control
    n_epochs = 10
    # Per-epoch step cap for a short debugging run (101 steps).  Set to None
    # to iterate over the whole dataset every epoch.
    max_iters_per_epoch = 101
    print_every = 10
    sample_epochs = (1, 2, 5, 10)
    sample_dir = '../work/samples'

    # tensorflow placeholder
    x = tf.placeholder(tf.float32, [None, npx, npx, nc])
    y = None  # no label input: unconditional training

    # graphs
    dcgan = DCGAN(batch_size=batch_size, s_size=4, z_dim=nz, y_dim=None)
    logits_list = dcgan.inference(x, y)
    g_loss, d_loss = dcgan.loss(logits_list)
    train_op = dcgan.train(g_loss, d_loss, learning_rate=lr)

    # sample-image op
    images = dcgan.sample_images(label=None)

    init = tf.global_variables_initializer()

    # Create the output directory up front so saving samples cannot fail on
    # a missing path after an epoch of training.
    if not os.path.isdir(sample_dir):
        os.makedirs(sample_dir)

    with tf.Session() as sess:
        sess.run(init)

        # Batches per epoch, rounding up for a final partial batch
        # (for LSUN bedroom this is 23696), then apply the step cap.
        n_sample = bedroom.num_examples
        n_loop = (n_sample + batch_size - 1) // batch_size
        if max_iters_per_epoch is not None:
            n_loop = min(n_loop, max_iters_per_epoch)

        for e in range(1, n_epochs + 1):
            for i in range(n_loop):
                batch_x = bedroom.next_batch(batch_size, img_size=npx)
                batch_img = batch_x.reshape([-1, npx, npx, nc])
                fd_train = {x: batch_img}
                sess.run(train_op, feed_dict=fd_train)

                # Evaluating the losses is a separate session run, so only
                # do it on the steps where they are printed.
                if i % print_every == 0:
                    g_loss_np, d_loss_np = sess.run([g_loss, d_loss],
                                                    feed_dict=fd_train)
                    print(('epoch {:>5d}: ({:>8d} /{:>8d}) : '
                           'g_loss={:>10.4f}, d_loss={:>10.4f}').format(
                               e, i, n_loop, g_loss_np, d_loss_np))

            # Generate sample images after selected epochs
            if e in sample_epochs:
                fn_sample = os.path.join(sample_dir,
                                         'bedroom_' + str(e) + '.jpg')
                generated = sess.run(images)
                with open(fn_sample, 'wb') as fp:
                    fp.write(generated)
示例#3
0
def train():
    """Train the DCGAN on the CSV-listed dataset and periodically save samples.

    Relies on module-level settings defined elsewhere in the file
    (DATA_DIR, SAMPLE_SIZE, CSVFILE, Z_DIM, EPOCHS, BATCH_SIZE,
    LOG_DEVICE_PLACEMENT, FLAGS.sample_dir).  TensorBoard summaries are
    written to ``./logs``; on every 10th step (1, 11, 21, ...) the result of
    ``dcgan.generate_images`` is written unchanged to
    ``FLAGS.sample_dir/out_<step>.png``.

    The summary writer, the queue-runner threads and the session are shut
    down even if training raises.
    """
    with tf.Graph().as_default():
        # data: input tensors are produced by queue runners started below
        dataset = traindataset.DataSet(DATA_DIR, SAMPLE_SIZE)
        images, _ = dataset.csv_inputs(CSVFILE)  # labels are not used

        z = tf.placeholder(tf.float32, [None, Z_DIM], name='z')

        dcgan = DCGAN("test", "./checkpoint")
        _, logits, logits_, G_sum, z_sum, d_sum, d__sum = dcgan.inference(
            images, z)
        (d_loss_fake, d_loss_real, d_loss_real_sum, d_loss_fake_sum,
         d_loss_sum, g_loss_sum, d_loss, g_loss) = dcgan.loss(logits, logits_)

        # Split trainable variables by name so each optimizer only updates
        # its own network.  NOTE(review): this is a substring match, so any
        # variable whose name merely contains 'd_' / 'g_' is included too.
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # train operations
        d_optim = D_train_op(d_loss, d_vars)
        g_optim = G_train_op(g_loss, g_vars)

        # sample op (presumably a 4x4 image grid encoded as PNG bytes, since
        # its result is written straight to a .png file below)
        generate_images = dcgan.generate_images(z, 4, 4)

        # per-network merged summaries
        g_sum = tf.merge_summary(
            [z_sum, d__sum, G_sum, d_loss_fake_sum, g_loss_sum])
        d_sum = tf.merge_summary([z_sum, d_sum, d_loss_real_sum, d_loss_sum])

        init_op = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=LOG_DEVICE_PLACEMENT))
        writer = tf.train.SummaryWriter("./logs", sess.graph_def)
        try:
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                counter = 1
                start_time = time.time()

                for epoch in range(EPOCHS):
                    for idx in range(dataset.batch_idxs):
                        batch_z = np.random.uniform(
                            -1, 1, [BATCH_SIZE, Z_DIM]).astype(np.float32)
                        feed = {z: batch_z}

                        # discriminator update
                        _, summary_str = sess.run([d_optim, d_sum], feed)
                        writer.add_summary(summary_str, counter)

                        # generator update, run twice per discriminator step
                        for _g_step in range(2):
                            _, summary_str = sess.run([g_optim, g_sum], feed)
                            writer.add_summary(summary_str, counter)

                        # Fetch all monitored losses in a single run so they
                        # describe the same input batch; separate runs would
                        # presumably each pull a fresh batch from the queue.
                        errD_fake, errD_real, errG = sess.run(
                            [d_loss_fake, d_loss_real, g_loss], feed)

                        print(
                            "epochs: %02d %04d/%04d time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                            % (epoch, idx, dataset.batch_idxs,
                               time.time() - start_time,
                               errD_fake + errD_real, errG))

                        if counter % 10 == 1:
                            print("generate samples.")
                            generated_image_eval = sess.run(
                                generate_images, feed)
                            filename = os.path.join(
                                FLAGS.sample_dir, 'out_%05d.png' % counter)
                            with open(filename, 'wb') as f:
                                f.write(generated_image_eval)
                        counter += 1
            finally:
                # stop the input-pipeline threads before closing the session
                coord.request_stop()
                coord.join(threads)
        finally:
            writer.close()
            sess.close()
示例#4
0
def main():
    """Train a label-conditioned DCGAN on CIFAR data and save samples.

    The generator/discriminator losses on the last batch of each epoch are
    printed once per epoch.  After epochs 10, 20, 50, 100, 200 and 300 the
    sampler is fed one batch of validation labels and its result is written
    unchanged to ``../work/samples/cifar_<epoch>.jpg``.

    Raises:
        ValueError: if the training split contains no examples.
    """
    import os

    # load data
    cifar = load_data('../data/')
    n_sample = cifar.train.num_examples
    if n_sample == 0:
        # Without at least one batch there is no loss to report per epoch.
        raise ValueError('training set is empty')

    # hyper-parameters
    nc = 3            # number of channels in image
    ny = 10           # number of classes
    batch_size = 128  # number of examples in batch
    npx = 32          # width/height of images in pixels
    nz = 100          # dimension of z
    lr = 0.0002       # initial learning rate for adam

    # run control
    n_epochs = 300
    sample_epochs = (10, 20, 50, 100, 200, 300)
    sample_dir = '../work/samples'

    # tensorflow placeholders
    x = tf.placeholder(tf.float32, [None, npx, npx, nc])
    y = tf.placeholder(tf.float32, [None, ny])         # labels for training
    y_target = tf.placeholder(tf.float32, [None, ny])  # labels for sampling

    # graphs
    dcgan = DCGAN(batch_size=batch_size, s_size=4, z_dim=nz, y_dim=ny)
    logits_list = dcgan.inference(x, y)
    g_loss, d_loss = dcgan.loss(logits_list)
    train_op = dcgan.train(g_loss, d_loss, learning_rate=lr)

    # label-conditioned sample-image op
    images = dcgan.sample_images(label=y_target)

    init = tf.global_variables_initializer()

    # Create the output directory up front so saving samples cannot fail on
    # a missing path after many epochs of training.
    if not os.path.isdir(sample_dir):
        os.makedirs(sample_dir)

    with tf.Session() as sess:
        sess.run(init)

        # batches per epoch, rounding up for a final partial batch
        n_loop = (n_sample + batch_size - 1) // batch_size

        for e in range(1, n_epochs + 1):
            for _ in range(n_loop):
                batch_x, batch_y = cifar.train.next_batch(batch_size)
                batch_img = batch_x.reshape([-1, npx, npx, nc])
                fd_train = {x: batch_img, y: batch_y}
                sess.run(train_op, feed_dict=fd_train)

            # Only the last batch's losses are reported, so evaluate them once
            # per epoch instead of after every training step.
            g_loss_np, d_loss_np = sess.run([g_loss, d_loss],
                                            feed_dict=fd_train)
            print('epoch {:>5d}: g_loss={:>11.4f}, d_loss={:>11.4f}'.format(
                e, g_loss_np, d_loss_np))

            # Generate sample images after selected epochs
            if e in sample_epochs:
                _, batch_yv = cifar.validation.next_batch(batch_size)
                fn_sample = os.path.join(sample_dir,
                                         'cifar_' + str(e) + '.jpg')
                generated = sess.run(images, feed_dict={y_target: batch_yv})
                with open(fn_sample, 'wb') as fp:
                    fp.write(generated)