Example #1
def sample_generator_images(hparams):
    """Sample random images from the generator"""

    # Create the generator
    _, x_hat, restore_path, restore_dict = mnist_model_def.vae_gen(hparams)

    # Get a session
    sess = tf.Session()

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    images = {}
    counter = 0
    rounds = int(math.ceil(hparams.num_input_images / hparams.batch_size))
    for _ in range(rounds):
        images_mat = sess.run(x_hat)
        for image in images_mat:
            if counter < hparams.num_input_images:
                images[counter] = image
                counter += 1

    # Reset TensorFlow graph
    sess.close()
    tf.reset_default_graph()

    return images
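
A minimal invocation sketch for the function above. The Hparams stub is hypothetical and only fills in the fields this function reads directly; mnist_model_def.vae_gen(hparams) may require further fields not shown here.

# Hypothetical stand-in for the real hyperparameter object (assumption)
class Hparams(object):
    num_input_images = 64  # total number of samples to collect
    batch_size = 16        # samples produced per sess.run(x_hat)

images = sample_generator_images(Hparams())  # dict: index -> image array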
Example #2
def vae_gen_estimator(hparams):

    # Set up placeholders
    A = tf.placeholder(tf.float32, shape=(hparams.n_input, hparams.num_measurements), name='A')
    y = tf.placeholder(tf.float32, shape=(1, hparams.num_measurements), name='y')

    # Create the generator
    z, x_hat, restore_path, restore_dict = mnist_model_def.vae_gen(1)
    z_likelihood_loss = tf.reduce_sum(z ** 2)

    # measure the generator output
    y_hat = tf.matmul(x_hat, A, name='y_hat')
    measurement_loss = tf.reduce_mean((y - y_hat) ** 2)

    # define total loss
    loss = tf.add(measurement_loss/(hparams.noise_std**2), 10*z_likelihood_loss, name='loss')

    # Set up gradient descent w.r.t. z
    hparams.learning_rate = hparams.learning_rate * (hparams.noise_std**2)
    opt = utils.get_optimizer(hparams)
    update_op = opt.minimize(loss, var_list=[z], name='update_op')

    # Get a session
    sess = tf.Session()

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(A_val, y_val, hparams):
        """Function that returns the estimated image"""
        measurement_loss_best = 1e10
        for _ in range(hparams.num_random_restarts):
            sess.run([z.initializer])
            for _ in range(hparams.max_update_iter):
                feed_dict = {A: A_val, y: y_val}
                _, measurement_loss_val = sess.run([update_op, measurement_loss], feed_dict=feed_dict)
            if measurement_loss_val < measurement_loss_best:
                measurement_loss_best = measurement_loss_val
                x_hat_best_val = sess.run(x_hat)
        return x_hat_best_val

    return estimator
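
A hedged usage sketch for this estimator, with shapes taken from the placeholders above (A: (n_input, num_measurements), y: (1, num_measurements)). x_true is an assumed flattened test image; the Gaussian A is illustrative, not prescribed by the source.

import numpy as np

estimator = vae_gen_estimator(hparams)
A_val = np.random.randn(hparams.n_input, hparams.num_measurements)
y_val = x_true.reshape(1, -1).dot(A_val)  # noiseless measurements of x_true (assumed setup)
x_hat_val = estimator(A_val, y_val, hparams)  # estimated image, shape (1, n_input) assumed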
Example #3
def sample_generator_images(hparams):
    """Sample random images from the generator"""

    # Create the generator
    z, x_hat, restore_path, restore_dict, b3 = mnist_model_def.vae_gen(hparams)

    # Get a session
    sess = tf.Session()

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    images = {}
    zs = []
    counter = 0
    rounds = int(math.ceil(hparams.num_input_images / hparams.batch_size))
    for _ in range(rounds):
        z_val, images_mat = sess.run([z, x_hat])
        #print(sess.run(b3))
        for i, image in enumerate(images_mat):
            if counter < hparams.num_input_images:
                images[counter] = image
                zs.append(z_val[i])
                counter += 1

    # Reset TensorFlow graph
    sess.close()
    tf.reset_default_graph()
    hparams.z_from_gen = np.asarray(zs)
    hparams.images_mat = images_mat
    np.save(
        utils.get_checkpoint_dir(hparams, hparams.model_types[0]) + 'z.npy',
        hparams.z_from_gen)
    np.save(
        utils.get_checkpoint_dir(hparams, hparams.model_types[0]) +
        'images.npy', hparams.images_mat)

    return images
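
The two np.save calls above make the sampled latents and images recoverable later; a sketch of loading them back from the same checkpoint directory (note that images.npy holds only the last sampled batch, as assigned above):

import numpy as np

ckpt_dir = utils.get_checkpoint_dir(hparams, hparams.model_types[0])
z_from_gen = np.load(ckpt_dir + 'z.npy')       # latent vectors, one per sampled image
images_mat = np.load(ckpt_dir + 'images.npy')  # the last sampled image batch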
Example #4
def sample_generator_images(sample_size):
    """Sample random images from the generator"""

    # Create the generator
    _, x_hat, restore_path, restore_dict = mnist_model_def.vae_gen(sample_size)

    # Get a session
    sess = tf.Session()

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)
    images = sess.run(x_hat)
    images = {i: image for (i, image) in enumerate(images)}

    # Reset TensorFlow graph
    sess.close()
    tf.reset_default_graph()

    return images
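
In this variant vae_gen takes the sample count directly, so the call is just:

images = sample_generator_images(64)  # dict of 64 generated images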
Example #5
def vae_estimator(hparams):

    # Get a session
    sess = tf.Session()

    # Set up placeholders
    #A = tf.placeholder(tf.float32, shape=(hparams.batch_size, 100), name='A')
    y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.n_input), name='y_batch')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, x_hat_batch, restore_path, restore_dict = mnist_model_def.vae_gen(hparams)

    # measure the estimate
    y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Set up gradient descent
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss, var_list=var_list, global_step=global_step, name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(y_batch_val, z_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)
        assign_z_opt_op = z_batch.assign(z_batch_val)

        feed_dict = {y_batch: y_batch_val}

        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            sess.run(assign_z_opt_op)
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val,
                                            m_loss2_val,
                                            zp_loss_val))

            x_hat_batch_val, z_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, z_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, z_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
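
Unlike the other estimators, this one takes an explicit z_batch_val to warm-start the latent variables (via assign_z_opt_op). A hedged usage sketch, assuming z_batch has shape (batch_size, hparams.n_z):

import numpy as np

estimator = vae_estimator(hparams)
z_init = np.random.randn(hparams.batch_size, hparams.n_z)  # warm-start latents (shape assumed)
best = estimator(y_batch_val, z_init, hparams)  # best estimate(s) as kept by utils.BestKeeper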
Example #6
def vae_estimator(hparams):

    # Get a session
    sess = tf.Session()

    # Set up placeholders
    A = tf.placeholder(tf.float32,
                       shape=(hparams.n_input, hparams.num_measurements),
                       name='A')
    y_batch = tf.placeholder(tf.float32,
                             shape=(hparams.batch_size,
                                    hparams.num_measurements),
                             name='y_batch')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, x_hat_batch, restore_path, restore_dict = mnist_model_def.vae_gen(
        hparams.batch_size)

    # measure the estimate
    y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Set up gradient descent
    global_step = tf.Variable(0, trainable=False)
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss,
                             var_list=[z_batch],
                             global_step=global_step,
                             name='update_op')

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(A_val, y_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)
        feed_dict = {A: A_val, y_batch: y_batch_val}
        for i in range(hparams.num_random_restarts):
            sess.run([z_batch.initializer])
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val,
                                            zp_loss_val))

                if hparams.gif and ((j % hparams.gif_iter) == 0):
                    images = sess.run(x_hat_batch, feed_dict=feed_dict)
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)

            x_hat_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
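
A sketch of generating inputs consistent with y_hat_batch = tf.matmul(x_hat_batch, A) above. The 1/sqrt(m) scaling of the Gaussian measurement matrix is an assumption (standard in compressed-sensing setups), not something this code dictates; x_batch_val is an assumed batch of flattened ground-truth images.

import numpy as np

m = hparams.num_measurements
A_val = np.random.randn(hparams.n_input, m) / np.sqrt(m)
y_batch_val = x_batch_val.dot(A_val)  # x_batch_val: (batch_size, n_input), assumed
estimator = vae_estimator(hparams)
x_hat_batch_val = estimator(A_val, y_batch_val, hparams)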
def vae_estimator(hparams):
    # Get a session
    sess = tf.Session()

    # Set up placeholders
    A = tf.placeholder(tf.float32,
                       shape=(hparams.n_input, hparams.num_measurements),
                       name='A')
    y_batch = tf.placeholder(tf.float32,
                             shape=(hparams.batch_size,
                                    hparams.num_measurements),
                             name='y_batch')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, x_hat_batch, restore_path, restore_dict, _ = mnist_model_def.vae_gen(
        hparams)

    # measure the estimate
    if hparams.measurement_type == 'project':
        y_hat_batch = tf.identity(x_hat_batch, name='y_hat_batch')
    else:
        y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)

    #zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
    if hparams.stdv > 0:
        norm_val = 1 / (hparams.stdv**2)
    else:
        norm_val = 1e+20

    zp_loss_batch = tf.reduce_sum(
        (z_batch - tf.ones(tf.shape(z_batch)) * hparams.mean)**2 * norm_val,
        1)  # added normalization

    # define total loss

    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Set up gradient descent
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss,
                             var_list=var_list,
                             global_step=global_step,
                             name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(A_val, y_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)
        if hparams.measurement_type == 'project':
            #            if y_batch_val.shape[0]!=hparams.batch_size:
            #                y_batch_val_tmp = np.zeros((hparams.batch_size,hparams.num_measurements))
            #                y_batch_val_tmp[:y_batch_val.shape[0],:] = y_batch_val
            #                y_batch_val = y_batch_val_tmp

            #                print('Smaller INPUT NUMBER')#Or change hparams on the fly
            feed_dict = {y_batch: y_batch_val}
        else:
            feed_dict = {A: A_val, y_batch: y_batch_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(
                    logging_format.format(i, j, lr_val, total_loss_val,
                                          m_loss1_val, m_loss2_val,
                                          zp_loss_val))
                #print('n_z is {}'.format(hparams.n_z))
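                # Sanity check: with a positive z-prior weight, the prior term must contribute to the total loss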
                if total_loss_val == m_loss2_val and zp_loss_val > 0 and hparams.zprior_weight > 0:
                    raise ValueError('NONONO')

                if hparams.gif and ((j % hparams.gif_iter) == 0):
                    images = sess.run(x_hat_batch, feed_dict=feed_dict)
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)

            x_hat_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
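
When measurement_type == 'project', y_hat_batch is simply the generator output, so this estimator projects target images directly onto the range of the generator and A is never fed (passing None is safe). A hedged sketch; note the y_batch placeholder shape requires num_measurements == n_input in this mode, and x_batch_val is an assumed batch of target images:

estimator = vae_estimator(hparams)
# x_batch_val: target images, shape (batch_size, n_input), assumed
x_hat_batch_val = estimator(None, x_batch_val, hparams)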