def q_net(observed, x, n_z, n_k, tau, n_particles, relaxed=False):
    """Build the variational network q(z | x) as a ZhuSuan BayesianNet.

    ``x`` is cast to float32 and passed through two 200-unit tanh dense
    layers. A final linear dense layer then produces ``n_z * n_k`` logits,
    which are reshaped to ``[-1, n_z, n_k]``. These logits parameterise the
    stochastic node ``'z'``.

    Args:
        observed: Mapping of observed node values handed to
            ``zs.BayesianNet``.
        x: Input tensor. Presumably ``[batch, n_x]``; TODO confirm.
        n_z: Number of categorical latent variables.
        n_k: Number of categories per latent variable.
        tau: Temperature passed to ``zs.ExpConcrete`` when ``relaxed``.
        n_particles: Number of samples drawn for ``'z'``.
        relaxed: If true, ``'z'`` is an ``ExpConcrete`` (relaxed, log-space)
            node. Otherwise it is a float32 ``OnehotCategorical`` node.

    Returns:
        The ``zs.BayesianNet`` containing the ``'z'`` node.
    """
    # NOTE(review): `q_net` is defined again later in this file. That later
    # definition shadows this one at import time, and it passes ZhuSuan's
    # `group_event_ndims` where this one passes `group_ndims`. Confirm which
    # API version / definition is intended.
    with zs.BayesianNet(observed=observed) as variational:
        hidden = tf.to_float(x)
        # Two hidden layers, created in the same order as before so the
        # auto-generated variable names (dense, dense_1) stay unchanged.
        for _ in range(2):
            hidden = tf.layers.dense(hidden, 200, activation=tf.tanh)
        flat_logits = tf.layers.dense(hidden, n_z * n_k)
        logits = tf.reshape(flat_logits, [-1, n_z, n_k])
        if not relaxed:
            zs.OnehotCategorical('z', logits, n_samples=n_particles,
                                 group_ndims=1, dtype=tf.float32)
        else:
            zs.ExpConcrete('z', tau, logits, n_samples=n_particles,
                           group_ndims=1)
    return variational
def q_net(observed, x, n_z, n_k, tau, n_particles, relaxed=False):
    """Build the variational network q(z | x) as a ZhuSuan BayesianNet.

    ``x`` is cast to float32 and passed through two 200-unit tanh
    fully-connected layers. A final linear layer then produces
    ``n_z * n_k`` logits, which are reshaped to ``[batch, n_z, n_k]``. These
    logits parameterise the stochastic node ``'z'``.

    Args:
        observed: Mapping of observed node values handed to
            ``zs.BayesianNet``.
        x: Input tensor. Presumably ``[batch, n_x]``; TODO confirm.
        n_z: Number of categorical latent variables.
        n_k: Number of categories per latent variable.
        tau: Temperature passed to ``zs.ExpConcrete`` when ``relaxed``.
        n_particles: Number of samples drawn for ``'z'``.
        relaxed: If true, ``'z'`` is an ``ExpConcrete`` (relaxed, log-space)
            node. Otherwise it is a float32 ``OnehotCategorical`` node.

    Returns:
        The ``zs.BayesianNet`` containing the ``'z'`` node.
    """
    with zs.BayesianNet(observed=observed) as variational:
        lz_x = layers.fully_connected(tf.to_float(x), 200,
                                      activation_fn=tf.tanh)
        lz_x = layers.fully_connected(lz_x, 200, activation_fn=tf.tanh)
        z_logits = layers.fully_connected(lz_x, n_z * n_k, activation_fn=None)
        # `n` is not a parameter of this function. The previous
        # `[n, n_z, n_k]` raised NameError, or silently used some unrelated
        # global. Infer the batch dimension from `z_logits` instead.
        z_stacked_logits = tf.reshape(z_logits, [-1, n_z, n_k])
        # The stochastic node registers itself with `variational`; the
        # returned tensor is not needed here.
        if relaxed:
            zs.ExpConcrete('z', tau, z_stacked_logits,
                           n_samples=n_particles, group_event_ndims=1)
        else:
            zs.OnehotCategorical('z', z_stacked_logits, dtype=tf.float32,
                                 n_samples=n_particles, group_event_ndims=1)
    return variational
def vae(observed, n, n_x, n_z, n_k, tau, n_particles, relaxed=False):
    """Build the generative model p(z) p(x | z) as a ZhuSuan BayesianNet.

    The prior over ``'z'`` uses all-zero logits of shape ``[n, n_z, n_k]``.
    A sample of ``'z'`` is flattened to ``[n_particles, n, n_z * n_k]``. In
    the relaxed case it is first mapped out of log space with ``tf.exp``.
    The flattened sample is decoded by two 200-unit tanh fully-connected
    layers and a linear layer into ``n_x`` Bernoulli logits for node ``'x'``.

    Args:
        observed: Mapping of observed node values handed to
            ``zs.BayesianNet``.
        n: Batch size used for the prior logits.
        n_x: Dimensionality of the Bernoulli output ``'x'``.
        n_z: Number of categorical latent variables.
        n_k: Number of categories per latent variable.
        tau: Temperature passed to ``zs.ExpConcrete`` when ``relaxed``.
        n_particles: Number of samples drawn for ``'z'``.
        relaxed: If true, ``'z'`` is an ``ExpConcrete`` node. Otherwise it is
            a float32 ``OnehotCategorical`` node.

    Returns:
        The ``zs.BayesianNet`` containing the ``'z'`` and ``'x'`` nodes.
    """
    with zs.BayesianNet(observed=observed) as model:
        prior_logits = tf.zeros([n, n_z, n_k])
        flat_shape = [n_particles, n, n_z * n_k]
        if not relaxed:
            onehot_z = zs.OnehotCategorical('z', prior_logits,
                                            dtype=tf.float32,
                                            n_samples=n_particles,
                                            group_event_ndims=1)
            decoder_in = tf.reshape(onehot_z, flat_shape)
        else:
            log_z = zs.ExpConcrete('z', tau, prior_logits,
                                   n_samples=n_particles,
                                   group_event_ndims=1)
            # ExpConcrete samples live in log space; reshape, then exponentiate.
            decoder_in = tf.exp(tf.reshape(log_z, flat_shape))
        hidden = decoder_in
        # Two hidden layers, created in the same order as before so the
        # auto-generated variable names stay unchanged.
        for _ in range(2):
            hidden = layers.fully_connected(hidden, 200, activation_fn=tf.tanh)
        x_logits = layers.fully_connected(hidden, n_x, activation_fn=None)
        # The observation node registers itself with `model`.
        zs.Bernoulli('x', x_logits, group_event_ndims=1)
    return model