def run_generator(num, x1, x2, fig_name='sample.png'):
    """Sample the generator restored from asset/train/infogan and save a grid.

    Opens a new session and initialises it with ``tf.sg_init``. It then
    restores the 'generator' category from the latest checkpoint in
    ``asset/train/infogan``. Next it evaluates the module-level ``gen`` tensor,
    with the arguments fed into the module-level placeholders. Finally it draws
    the first 100 results as a 10x10 grayscale grid and saves it at 600 dpi.

    Args:
        num: value fed to ``target_num`` (shape/meaning defined where the
            placeholder is declared -- TODO document).
        x1: value fed to ``target_cval_1``.
        x2: value fed to ``target_cval_2``.
        fig_name: file name of the saved figure inside ``asset/train/infogan/``.
    """
    with tf.Session() as sess:
        tf.sg_init(sess)

        # Only variables in the 'generator' category are loaded.
        ckpt = tf.train.latest_checkpoint('asset/train/infogan')
        tf.sg_restore(sess, ckpt, category='generator')

        feed = {target_num: num, target_cval_1: x1, target_cval_2: x2}
        samples = sess.run(gen, feed)

        # Row-major walk over the 10x10 axes: cell k shows samples[k].
        _, axes = plt.subplots(10, 10, sharex=True, sharey=True)
        for idx, cell in enumerate(axes.flat):
            cell.imshow(samples[idx], 'gray')
            cell.set_axis_off()

        plt.savefig('asset/train/infogan/' + fig_name, dpi=600)
        tf.sg_info('Sample image saved to "asset/train/infogan/%s"' % fig_name)
        plt.close()
# Example #2
    def generate(self, prev_midi):
        """Evaluate ``self.next_token`` for the given input and return it.

        Opens a fresh session and initialises it with ``tf.sg_init``. It
        restores the latest checkpoint found under ``save/train/small``, then
        runs ``self.next_token`` with ``prev_midi`` fed into ``self.x``.

        Args:
            prev_midi: value fed to the ``self.x`` placeholder (its shape and
                encoding are defined elsewhere -- TODO document).

        Returns:
            The result of ``sess.run(self.next_token, ...)``.
        """
        # NOTE(review): the original author noted a choice between
        # self.next_token and self.preds as the fetched tensor; next_token is
        # the one evaluated here.
        with tf.Session() as sess:
            tf.sg_init(sess)
            checkpoint = tf.train.latest_checkpoint('save/train/small')
            tf.sg_restore(sess, checkpoint)
            return sess.run(self.next_token, {self.x: prev_midi})
# Example #3
def genIt(name='bird', num_batches=100):
    """Draw fake feature vectors from the trained GAN generator and save them.

    Builds ``generator`` on standard-normal noise of shape
    ``(batch_size, rand_dim)``. It restores the 'generator' and 'discriminator'
    variables from the latest checkpoint in ``asset/train/gan`` and evaluates
    the generator ``num_batches`` times. All batches are stacked into one
    array reshaped to ``(-1, 4096)`` and saved to
    ``../data/fake_<name>_negative.npy``.

    Args:
        name: tag inserted into the output file name.
        num_batches: number of generator batches to draw. The default of 100
            matches the previously hard-coded count.
    """
    z = tf.random_normal((batch_size, rand_dim))
    # NOTE(review): every call adds another generator to the default graph;
    # whether repeated calls work depends on how `generator` handles variable
    # reuse -- TODO confirm.
    gen = generator(z)
    with tf.Session() as sess:
        # Initialise variables and set the sugartensor phase flag to False
        # (presumably inference mode for batch-norm/dropout -- TODO confirm).
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.sg_phase().assign(False)))
        tf.sg_restore(sess,
                      tf.train.latest_checkpoint('asset/train/gan'),
                      category=['generator', 'discriminator'])
        # z is a random op, so each run draws fresh noise and a new batch.
        fake_features = [sess.run(gen) for _ in range(num_batches)]
    # Assumes each batch flattens to rows of 4096 features -- TODO confirm
    # against the generator's output shape.
    np.save('../data/fake_' + name + '_negative.npy',
            np.array(fake_features).reshape((-1, 4096)))
# Example #4
def testIt():
    """Score the test set with the trained discriminator and save the scores.

    Builds ``discriminator`` on a float32 placeholder of shape
    ``(None, 4096)``. It restores the 'generator' and 'discriminator' variables
    from the latest checkpoint in ``asset/train/gan`` and evaluates the
    discriminator on ``raw.test``. It then prints how many scores exceed 0.5
    and saves the raw scores to ``dm_bird.npy``.
    """
    data = raw
    x = tf.placeholder(tf.float32, [None, 4096])
    disc_real = discriminator(x)
    np.set_printoptions(precision=3, suppress=True)
    with tf.Session() as sess:
        # Initialise variables and set the sugartensor phase flag to False
        # (presumably inference mode -- TODO confirm).
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.sg_phase().assign(False)))
        # restore parameters
        tf.sg_restore(sess,
                      tf.train.latest_checkpoint('asset/train/gan'),
                      category=['generator', 'discriminator'])
        ans = sess.run(disc_real, feed_dict={x: np.array(data.test)})
        # Call form of print: identical output on Python 2 and valid on
        # Python 3, where the bare print statement is a SyntaxError.
        print(np.sum(ans > 0.5))
        np.save('dm_bird.npy', ans)
# Example #5
def _save_sample_grid(imgs, path, rows=10, cols=10):
    """Plot imgs[0 .. rows*cols-1] as a grayscale grid and save it to *path*.

    Cells are filled row-major (cell k shows imgs[k]). The figure is saved at
    600 dpi, the save is logged with ``tf.sg_info``, and that figure is then
    closed. Assumes ``imgs`` holds at least rows*cols images -- TODO confirm
    against batch_size.
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True,
                             squeeze=False)
    for k, cell in enumerate(axes.flat):
        cell.imshow(imgs[k], 'gray')
        cell.set_axis_off()
    fig.savefig(path, dpi=600)
    tf.sg_info('Sample image saved to "%s"' % path)
    plt.close(fig)


z = tf.random_normal((batch_size, rand_dim))

# generator
gen = generator(z).sg_squeeze()

#
# draw samples
#

with tf.Session() as sess:

    tf.sg_init(sess)

    # restore parameters
    tf.sg_restore(sess,
                  tf.train.latest_checkpoint('asset/train/gan'),
                  category='generator')

    # run generator and save a 10x10 grid of its output
    imgs = sess.run(gen)
    _save_sample_grid(imgs, 'asset/train/gan/sample.png')

# Decoder network built on the same noise tensor `z`.
# NOTE(review): this was indented inside the GAN session above, apparently a
# copy/paste artifact. Building graph ops does not need an open session, so it
# now lives at top level.
# NOTE(review): no 'decoder' scope/context is visible around these layers,
# yet they are restored below with category='decoder'. The upstream example
# presumably wrapped them in tf.sg_context(name='decoder', ...) -- TODO
# confirm, otherwise the restore may not populate these variables.
gen = (z.sg_dense(dim=1024)
       .sg_dense(dim=7 * 7 * 128)
       .sg_reshape(shape=(-1, 7, 7, 128))
       .sg_upconv(dim=64)
       .sg_upconv(dim=1, act='sigmoid')
       .sg_squeeze())

#
# draw samples
#

with tf.Session() as sess:

    tf.sg_init(sess)

    # restore parameters
    tf.sg_restore(sess,
                  tf.train.latest_checkpoint('asset/train/vae'),
                  category='decoder')

    # run decoder and save a 10x10 grid of its output
    imgs = sess.run(gen)
    _save_sample_grid(imgs, 'asset/train/vae/sample.png')