Пример #1
0
def main(args):
    """Train a VAE or DCVAE on MNIST, periodically saving samples, then checkpoint.

    Args:
        args: parsed options. Fields read directly here: ``model_type``
            (``"vae"`` or ``"dcvae"``), ``save_dir``, ``img_dir``,
            ``log_dir``, ``train_dir``, ``batch_size``, ``nb_epoch``,
            ``log_period``, ``anneal``, ``anneal_start``, ``save_period``
            and ``latent_dim``. ``args`` is also handed to the model
            constructor, which may read further fields.

    Raises:
        ValueError: if ``args.model_type`` is neither ``"vae"`` nor ``"dcvae"``.
    """
    # Fail fast: an unknown type used to leave ``model`` unbound, so the run
    # only died (NameError) after creating directories and loading the data.
    if args.model_type not in ("vae", "dcvae"):
        raise ValueError("unknown model_type %r (expected 'vae' or 'dcvae')"
                         % (args.model_type,))

    # Outputs are grouped per model type so runs of the two models don't clash.
    save_dir = os.path.join(args.save_dir, args.model_type)
    img_dir = os.path.join(args.img_dir, args.model_type)
    log_dir = os.path.join(args.log_dir, args.model_type)
    for directory in (save_dir, log_dir, img_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    mnist = utils.read_data_sets(args.train_dir)
    summary_writer = tf.summary.FileWriter(log_dir)
    config_proto = utils.get_config_proto()

    sess = tf.Session(config=config_proto)
    try:
        if args.model_type == "vae":
            model = VAE(args, sess, name="vae")
        else:
            model = DCVAE(args, sess, name="dcvae")

        total_batch = mnist.train.num_examples // args.batch_size

        for epoch in range(1, args.nb_epoch + 1):
            print("Epoch %d start with learning rate %f" % (
                epoch, model.learning_rate.eval(sess)))
            print("- " * 50)
            # Start of the current logging window. It is reset only after a
            # log line is printed, so the reported time covers every step
            # since the previous log (it used to be reset after every step,
            # which made the printed time ~0s).
            step_start_time = time.time()
            for _ in range(total_batch):
                x_batch, _y_batch = mnist.train.next_batch(args.batch_size)
                _, loss, loss_rec, loss_kl, global_step, summaries = model.train(
                    x_batch)
                summary_writer.add_summary(summaries, global_step)

                if global_step % args.log_period == 0:
                    print("global step %d, loss %.9f, loss_rec %.9f, loss_kl %.9f, time %.2fs"
                          % (global_step, loss, loss_rec, loss_kl,
                             time.time() - step_start_time))
                    step_start_time = time.time()

            if args.anneal and epoch >= args.anneal_start:
                sess.run(model.learning_rate_decay_op)

            if epoch % args.save_period == 0:
                # 100 samples, saved as a 10 x 10 grid of 28x28x1 images.
                z = np.random.normal(size=[100, args.latent_dim])
                if args.model_type == "vae":
                    samples = model.generate(z)
                else:
                    samples = model.generate(z, 100)
                gen_images = np.reshape(samples, (100, 28, 28, 1))
                utils.save_images(gen_images, [10, 10],
                                  os.path.join(img_dir, "sample%s.jpg" % epoch))

        model.saver.save(sess, os.path.join(save_dir, "model.ckpt"))
        print("Model stored....")
    finally:
        # Flush pending summaries to disk and release the session, even if
        # training fails part-way.
        summary_writer.close()
        sess.close()
Пример #2
0
    def test_generate(self):
        """Decoding a batch of 5 all-zero latent vectors yields shape (5, 4096)."""
        vae = VAE()
        # Batch of 5 latent vectors of size 10, all zeros.
        n_samples = 5
        latent = np.zeros((n_samples, 10), dtype=np.float32)

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            generated = vae.generate(session, latent)
            self.assertEqual((n_samples, 4096), generated.shape)
Пример #3
0
class Sampler:
    """Helper for sampling and displaying images from a trained ``VAE``.

    On construction it builds a ``VAE``, calls ``load_model('save')`` on it
    and draws an initial latent vector. The MNIST dataset is loaded lazily
    the first time a real sample is requested.
    """

    def __init__(self):
        self.mnist = None
        self.model = VAE()
        self.model.load_model('save')
        # Most recently used latent vector; updated by generate().
        self.z = self.generate_z()

    def get_random_mnist(self, with_label=False):
        """Return one random MNIST sample, loading the dataset on first use.

        Returns ``(image, label)`` when ``with_label`` is truthy, otherwise
        ``self.mnist.next_batch(1)[0]``.
        """
        if self.mnist is None:
            self.mnist = read_data_sets()
        if with_label:
            data, label = self.mnist.next_batch(1, with_label)
            return data[0], label[0]
        return self.mnist.next_batch(1)[0]

    def get_random_specific_mnist(self, label=2, max_tries=100):
        """Draw random MNIST images until one labelled ``label`` appears.

        Makes one initial draw plus up to ``max_tries`` redraws. If no match
        turns up, the last drawn image is returned anyway, so the result may
        carry a different label.
        """
        m, l = self.get_random_mnist(with_label=True)
        for _ in range(max_tries):
            if l == label:
                break
            m, l = self.get_random_mnist(with_label=True)
        return m

    def generate_random_label(self, label):
        """Show a random image with ``label``, then the image generated from its encoding."""
        m = self.get_random_specific_mnist(label)
        self.show_image(m)
        self.show_image_from_z(self.encode(m))

    def generate_z(self):
        """Draw a ``(1, z_dim)`` float32 standard-normal latent vector."""
        z = np.random.normal(size=(1, self.model.z_dim)).astype(np.float32)
        return z

    def encode(self, mnist_data):
        """Add a leading batch axis to ``mnist_data`` and return ``model.transform`` of it."""
        new_shape = [1] + list(mnist_data.shape)
        return self.model.transform(np.reshape(mnist_data, new_shape))

    def generate(self, z=None):
        """Generate one image from ``z`` (a fresh random vector if ``None``).

        The latent vector actually used is stored in ``self.z``; the first
        element of ``model.generate(z)`` is returned.
        """
        if z is None:
            z = self.generate_z()
        else:
            z = np.reshape(z, (1, self.model.z_dim))
        self.z = z
        return self.model.generate(z)[0]

    def _image_dims(self, image_data):
        """Return ``(height, width, channels)``; channels is 1 for 2-D input."""
        shape = np.shape(image_data)
        c_dim = shape[2] if len(shape) > 2 else 1
        return shape[0], shape[1], c_dim

    def show_image(self, image_data):
        """Display one image with matplotlib.

        ``image_data`` is an array shaped ``(height, width)`` or
        ``(height, width, depth)``, NOT a ``PIL.Image``. Single-channel
        images are drawn with the ``'Greys'`` colormap.
        """
        plt.subplot(1, 1, 1)
        # Channel count comes from the data; it used to be hard-coded to 1,
        # which made the multi-channel branch unreachable.
        y_dim, x_dim, c_dim = self._image_dims(image_data)
        if c_dim > 1:
            plt.imshow(image_data, interpolation='nearest')
        else:
            plt.imshow(image_data.reshape(y_dim, x_dim),
                       cmap='Greys',
                       interpolation='nearest')
        plt.axis('off')
        plt.show()

    def show_image_from_z(self, z):
        """Display the image generated from latent vector ``z``."""
        self.show_image(self.generate(z))

    def to_image(self, image_data):
        """Convert an array with values in [0, 1] into a ``PIL.Image``.

        Values are inverted (``1 - v``), clipped to [0, 1] and scaled to
        8-bit. Single-channel input becomes a 2-D (greyscale) array.
        """
        y_dim, x_dim, c_dim = self._image_dims(image_data)
        # Clip before the uint8 cast: out-of-range floats (e.g. from
        # integrate_image) would otherwise wrap around instead of saturating.
        img_data = np.clip(1 - np.asarray(image_data), 0.0, 1.0)
        if c_dim > 1:
            shape = (y_dim, x_dim, c_dim)
        else:
            shape = (y_dim, x_dim)
        img_data = np.array(img_data.reshape(shape) * 255.0, dtype=np.uint8)
        return Image.fromarray(img_data)

    def diff_image(self, image_data):
        """Return the 2-D finite difference of a ``(height, width, depth)`` image.

        Each pixel becomes ``x[i,j] - x[i-1,j] - x[i,j-1] + x[i-1,j-1]``,
        with out-of-range neighbours treated as 0, so ``integrate_image``
        inverts it exactly. On images whose first row and column are zero
        this matches the previous interior-only difference. The input is
        not modified. Assumes a float array (integer dtypes may overflow).
        """
        m2 = np.array(image_data)  # makes a copy
        m2[1:, :, :] = m2[1:, :, :] - m2[:-1, :, :]
        m2[:, 1:, :] = m2[:, 1:, :] - m2[:, :-1, :]
        return m2

    def integrate_image(self, image_data):
        """Invert ``diff_image``: cumulative sum down rows, then across columns."""
        m3 = np.array(image_data)
        m3 = m3.cumsum(axis=0)
        m3 = m3.cumsum(axis=1)
        return m3
Пример #4
0
class Sampler:
  """Helper for sampling and displaying images from a trained ``VAE``.

  On construction it builds a ``VAE``, calls ``load_model('save')`` on it
  and draws an initial latent vector. The MNIST dataset is loaded lazily
  the first time a real sample is requested.
  """
  def __init__(self):
    self.mnist = None
    self.model = VAE()
    self.model.load_model('save')
    # Most recently used latent vector; updated by generate().
    self.z = self.generate_z()
  def get_random_mnist(self, with_label = False):
    """Return one random MNIST sample, loading the dataset on first use.

    Returns ``(image, label)`` when ``with_label`` is truthy, otherwise
    ``self.mnist.next_batch(1)[0]``.
    """
    if self.mnist is None:
      self.mnist = read_data_sets()
    if with_label:
      data, label = self.mnist.next_batch(1, with_label)
      return data[0], label[0]
    return self.mnist.next_batch(1)[0]
  def get_random_specific_mnist(self, label = 2, max_tries = 100):
    """Draw random MNIST images until one labelled ``label`` appears.

    Makes one initial draw plus up to ``max_tries`` redraws. If no match
    turns up, the last drawn image is returned anyway, so the result may
    carry a different label.
    """
    m, l = self.get_random_mnist(with_label = True)
    for _ in range(max_tries):
      if l == label:
        break
      m, l = self.get_random_mnist(with_label = True)
    return m
  def generate_random_label(self, label):
    """Show a random image with ``label``, then the image generated from its encoding."""
    m = self.get_random_specific_mnist(label)
    self.show_image(m)
    self.show_image_from_z(self.encode(m))
  def generate_z(self):
    """Draw a ``(1, z_dim)`` float32 standard-normal latent vector."""
    z = np.random.normal(size=(1, self.model.z_dim)).astype(np.float32)
    return z
  def encode(self, mnist_data):
    """Add a leading batch axis to ``mnist_data`` and return ``model.transform`` of it."""
    new_shape = [1]+list(mnist_data.shape)
    return self.model.transform(np.reshape(mnist_data, new_shape))
  def generate(self, z=None):
    """Generate one image from ``z`` (a fresh random vector if ``None``).

    The latent vector actually used is stored in ``self.z``; the first
    element of ``model.generate(z)`` is returned.
    """
    if z is None:
      z = self.generate_z()
    else:
      z = np.reshape(z, (1, self.model.z_dim))
    self.z = z
    return self.model.generate(z)[0]
  def _image_dims(self, image_data):
    """Return ``(height, width, channels)``; channels is 1 for 2-D input."""
    shape = np.shape(image_data)
    c_dim = shape[2] if len(shape) > 2 else 1
    return shape[0], shape[1], c_dim
  def show_image(self, image_data):
    """Display one image with matplotlib.

    ``image_data`` is an array shaped ``(height, width)`` or
    ``(height, width, depth)``, NOT a ``PIL.Image``. Single-channel images
    are drawn with the ``'Greys'`` colormap.
    """
    plt.subplot(1, 1, 1)
    # Channel count comes from the data; it used to be hard-coded to 1,
    # which made the multi-channel branch unreachable.
    y_dim, x_dim, c_dim = self._image_dims(image_data)
    if c_dim > 1:
      plt.imshow(image_data, interpolation='nearest')
    else:
      plt.imshow(image_data.reshape(y_dim, x_dim), cmap='Greys', interpolation='nearest')
    plt.axis('off')
    plt.show()
  def show_image_from_z(self, z):
    """Display the image generated from latent vector ``z``."""
    self.show_image(self.generate(z))
  def to_image(self, image_data):
    """Convert an array with values in [0, 1] into a ``PIL.Image``.

    Values are inverted (``1 - v``), clipped to [0, 1] and scaled to 8-bit.
    Single-channel input becomes a 2-D (greyscale) array.
    """
    y_dim, x_dim, c_dim = self._image_dims(image_data)
    # Clip before the uint8 cast: out-of-range floats (e.g. from
    # integrate_image) would otherwise wrap around instead of saturating.
    img_data = np.clip(1 - np.asarray(image_data), 0.0, 1.0)
    if c_dim > 1:
      shape = (y_dim, x_dim, c_dim)
    else:
      shape = (y_dim, x_dim)
    img_data = np.array(img_data.reshape(shape)*255.0, dtype=np.uint8)
    return Image.fromarray(img_data)
  def diff_image(self, image_data):
    """Return the 2-D finite difference of a ``(height, width, depth)`` image.

    Each pixel becomes ``x[i,j] - x[i-1,j] - x[i,j-1] + x[i-1,j-1]``, with
    out-of-range neighbours treated as 0, so ``integrate_image`` inverts it
    exactly. On images whose first row and column are zero this matches the
    previous interior-only difference. The input is not modified. Assumes a
    float array (integer dtypes may overflow).
    """
    m2 = np.array(image_data) # makes a copy
    m2[1:,:,:] = m2[1:,:,:]-m2[:-1,:,:]
    m2[:,1:,:] = m2[:,1:,:]-m2[:,:-1,:]
    return m2
  def integrate_image(self, image_data):
    """Invert ``diff_image``: cumulative sum down rows, then across columns."""
    m3 = np.array(image_data)
    m3 = m3.cumsum(axis=0)
    m3 = m3.cumsum(axis=1)
    return m3
Пример #5
0
# Show each test input next to its reconstruction, one row per sample.
# NOTE(review): subplot indices 2*i+1 / 2*i+2 assume cols == 2 -- confirm.
for i in range(rows):
    plt.subplot(rows, cols, 2 * i + 1)
    plt.imshow(x[i].reshape(28, 28), vmin=0, vmax=1, cmap="Greys_r")
    plt.title("test input")
    plt.colorbar()
    plt.subplot(rows, cols, 2 * i + 2)
    plt.imshow(x_reconstr[i].reshape(28, 28), vmin=0, vmax=1, cmap="Greys_r")
    plt.title("reconstruct")
    plt.colorbar()
# Lay out once, after every subplot exists. Calling this inside the loop
# redid the work for each row, and only the final call mattered.
plt.tight_layout()
plt.show()

# Check generation: decode a batch of random latent vectors and save them.
vae_model.eval()
# Inference only: don't record an autograd graph for the decoder pass.
with torch.no_grad():
    noise = torch.randn(batch_size, net_arch["n_z"], device=device)
    images = vae_model.generate(noise).cpu().detach()
torch.save(images, gen_images_path)
print(f"{gen_images_path} saved")

# Check the latent space. This needs a second VAE model trained with n_z=2.
# NOTE(review): this mutates the same net_arch dict vae_model was built
# from; if VAE keeps a reference to it, vae_model sees n_z == 2 too --
# consider building the 2-D model from a copy.
net_arch["n_z"] = 2

vae_model_2d = VAE(net_arch, lr, batch_size, device)
print(vae_model_2d)

if not os.path.exists(model_2d_save_path):
    vae_model_2d.train()
    n_epoch = 50
    for epoch in range(n_epoch):
        total_loss = 0.0
        count = 0