def q_net(observed, x, n_z, n_k, tau, n_particles, relaxed=False):
    """Build the variational network that produces the latent node 'z' from x.

    The input is cast to float and passed through two 200-unit tanh dense
    layers, then a linear dense layer with ``n_z * n_k`` outputs. Those
    outputs are reshaped to ``[-1, n_z, n_k]`` and used as the logits of a
    stochastic node named 'z'.

    Args:
        observed: Forwarded as ``observed=`` to ``zs.BayesianNet``.
        x: Input tensor; converted with ``tf.to_float`` before the first layer.
        n_z: Size of the second axis of the reshaped logits.
        n_k: Size of the last axis of the reshaped logits.
        tau: Second positional argument of ``zs.ExpConcrete``; only used when
            ``relaxed`` is true (presumably the relaxation temperature --
            confirm against the ZhuSuan API).
        n_particles: Forwarded as ``n_samples=`` to the 'z' node.
        relaxed: If true, 'z' is a ``zs.ExpConcrete`` node; otherwise a
            ``zs.OnehotCategorical`` node with ``dtype=tf.float32``.

    Returns:
        The ``zs.BayesianNet`` context object the node was created under.
    """
    # Keyword arguments common to both kinds of 'z' node.
    shared_kwargs = dict(n_samples=n_particles, group_ndims=1)

    with zs.BayesianNet(observed=observed) as variational:
        hidden = tf.to_float(x)
        for _ in range(2):
            hidden = tf.layers.dense(hidden, 200, activation=tf.tanh)

        flat_logits = tf.layers.dense(hidden, n_z * n_k)
        logits = tf.reshape(flat_logits, [-1, n_z, n_k])

        # The node handle is not used afterwards; it is created for its
        # effect inside the BayesianNet context.
        if not relaxed:
            _latent = zs.OnehotCategorical(
                'z', logits, dtype=tf.float32, **shared_kwargs)
        else:
            _latent = zs.ExpConcrete('z', tau, logits, **shared_kwargs)

    return variational
Beispiel #2
0
def q_net(observed, x, n_z, n_k, tau, n_particles, relaxed=False):
    """Build the variational network that produces the latent node 'z' from x.

    The input is cast to float and passed through two 200-unit tanh
    fully-connected layers, then a linear fully-connected layer with
    ``n_z * n_k`` outputs. Those outputs are reshaped to ``[-1, n_z, n_k]``
    and used as the logits of a stochastic node named 'z'.

    Args:
        observed: Forwarded as ``observed=`` to ``zs.BayesianNet``.
        x: Input tensor; converted with ``tf.to_float`` before the first layer.
        n_z: Size of the second axis of the reshaped logits.
        n_k: Size of the last axis of the reshaped logits.
        tau: Second positional argument of ``zs.ExpConcrete``; only used when
            ``relaxed`` is true (presumably the relaxation temperature --
            confirm against the ZhuSuan API).
        n_particles: Forwarded as ``n_samples=`` to the 'z' node.
        relaxed: If true, 'z' is a ``zs.ExpConcrete`` node; otherwise a
            ``zs.OnehotCategorical`` node with ``dtype=tf.float32``.

    Returns:
        The ``zs.BayesianNet`` context object the node was created under.
    """
    with zs.BayesianNet(observed=observed) as variational:
        lz_x = layers.fully_connected(tf.to_float(x),
                                      200,
                                      activation_fn=tf.tanh)
        lz_x = layers.fully_connected(lz_x, 200, activation_fn=tf.tanh)
        z_logits = layers.fully_connected(lz_x, n_z * n_k, activation_fn=None)
        # The leading dimension is inferred with -1: this function has no
        # `n` parameter, so naming it here would depend on an unrelated
        # global (or raise NameError).
        z_stacked_logits = tf.reshape(z_logits, [-1, n_z, n_k])
        if relaxed:
            z = zs.ExpConcrete('z',
                               tau,
                               z_stacked_logits,
                               n_samples=n_particles,
                               group_event_ndims=1)
        else:
            z = zs.OnehotCategorical('z',
                                     z_stacked_logits,
                                     dtype=tf.float32,
                                     n_samples=n_particles,
                                     group_event_ndims=1)
    return variational
Beispiel #3
0
def vae(observed, n, n_x, n_z, n_k, tau, n_particles, relaxed=False):
    """Build the generative network: latent node 'z' -> Bernoulli node 'x'.

    The 'z' node is created from all-zero logits of shape ``[n, n_z, n_k]``.
    Its sample is reshaped to ``[n_particles, n, n_z * n_k]`` (and passed
    through ``tf.exp`` in the relaxed case), fed through two 200-unit tanh
    fully-connected layers, and mapped by a linear layer to ``n_x`` logits
    for the Bernoulli node 'x'.

    Args:
        observed: Forwarded as ``observed=`` to ``zs.BayesianNet``.
        n: Leading dimension of the zero logits and second dimension of the
            reshaped latent sample.
        n_x: Number of outputs of the final linear layer.
        n_z: Size of the second axis of the 'z' logits.
        n_k: Size of the last axis of the 'z' logits.
        tau: Second positional argument of ``zs.ExpConcrete``; only used when
            ``relaxed`` is true (presumably the relaxation temperature --
            confirm against the ZhuSuan API).
        n_particles: Forwarded as ``n_samples=`` to the 'z' node and used as
            the leading dimension of the reshaped sample.
        relaxed: If true, 'z' is a ``zs.ExpConcrete`` node; otherwise a
            ``zs.OnehotCategorical`` node with ``dtype=tf.float32``.

    Returns:
        The ``zs.BayesianNet`` context object the nodes were created under.
    """
    with zs.BayesianNet(observed=observed) as model:
        prior_logits = tf.zeros([n, n_z, n_k])

        if not relaxed:
            latent = zs.OnehotCategorical(
                'z',
                prior_logits,
                dtype=tf.float32,
                n_samples=n_particles,
                group_event_ndims=1,
            )
            decoder_input = tf.reshape(latent, [n_particles, n, n_z * n_k])
        else:
            log_latent = zs.ExpConcrete(
                'z',
                tau,
                prior_logits,
                n_samples=n_particles,
                group_event_ndims=1,
            )
            flattened = tf.reshape(log_latent, [n_particles, n, n_z * n_k])
            # The relaxed sample is exponentiated before entering the decoder.
            decoder_input = tf.exp(flattened)

        hidden = decoder_input
        for _ in range(2):
            hidden = layers.fully_connected(hidden, 200, activation_fn=tf.tanh)

        pixel_logits = layers.fully_connected(hidden, n_x, activation_fn=None)
        # The node handle is not used afterwards; it is created for its
        # effect inside the BayesianNet context.
        _visible = zs.Bernoulli('x', pixel_logits, group_event_ndims=1)

    return model