def q_net(observed, x_dim, z_dim, n_z_per_x, y_dim=10):
    """Build the variational network for the latent node ``'z'``.

    Inside a ``zs.BayesianNet`` the function declares two int32
    ``zs.Empirical`` nodes, ``'x'`` with shape ``(None, x_dim)`` and ``'y'``
    with shape ``(None, y_dim)``, concatenates them along axis 1, passes the
    float-cast result through two 500-unit ReLU dense layers, and uses two
    further linear dense layers to produce the mean and log-std of the
    ``zs.Normal`` node ``'z'``.

    Args:
        observed: Forwarded unchanged to ``zs.BayesianNet(observed=...)``.
            Presumably a mapping from node names to observed tensors --
            confirm against the ZhuSuan version in use.
        x_dim: Size of the second axis of the ``'x'`` node.
        z_dim: Number of units in the mean / log-std layers of ``'z'``.
        n_z_per_x: Forwarded as ``n_samples`` to ``zs.Normal``.
        y_dim: Size of the second axis of the ``'y'`` node. Defaults to 10,
            the value that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        The ``zs.BayesianNet`` instance created by the ``with`` block.
    """
    with zs.BayesianNet(observed=observed) as variational:
        x = zs.Empirical('x', tf.int32, (None, x_dim))
        y = zs.Empirical('y', tf.int32, (None, y_dim))
        # Condition on both inputs by joining them along the feature axis.
        xy = tf.concat([x, y], 1)
        # NOTE: the order of the dense calls is kept as-is, since
        # tf.layers.dense auto-numbers its variable scopes by call order.
        lz_x = tf.layers.dense(tf.to_float(xy), 500, activation=tf.nn.relu)
        lz_x = tf.layers.dense(lz_x, 500, activation=tf.nn.relu)
        z_mean = tf.layers.dense(lz_x, z_dim)
        z_logstd = tf.layers.dense(lz_x, z_dim)
        # The return value is not needed here; the node is created for its
        # effect on the enclosing net (presumably ZhuSuan registers nodes
        # built inside the `with` block -- confirm for your version).
        zs.Normal('z', z_mean, logstd=z_logstd, group_ndims=1,
                  n_samples=n_z_per_x)
    return variational
def cvae(observed, x_dim, y_dim, z_dim, n, n_particles=1):
    """Build the generative network with nodes ``'y'``, ``'z'``, ``'x_mean'``, ``'x'``.

    Inside a ``zs.BayesianNet`` the function declares an int32
    ``zs.Empirical`` node ``'y'`` of shape ``(n, y_dim)`` and a ``zs.Normal``
    node ``'z'`` with a zero mean of shape ``(n, z_dim)`` and ``std=1.``.
    The float-cast ``'y'`` and the first slice of ``'z'`` are concatenated
    along axis 1 and fed through two 500-unit ReLU dense layers followed by
    a linear dense layer of width ``x_dim``. Those logits define the
    ``zs.Implicit`` node ``'x_mean'`` (their sigmoid) and the
    ``zs.Bernoulli`` node ``'x'``.

    Args:
        observed: Forwarded unchanged to ``zs.BayesianNet(observed=...)``.
        x_dim: Width of the final dense layer producing the logits of ``'x'``.
        y_dim: Size of the second axis of the ``'y'`` node.
        z_dim: Size of the second axis of the zero mean of ``'z'``.
        n: Size of the first axis of ``'y'`` and of the mean of ``'z'``.
        n_particles: Forwarded as ``n_samples`` to ``zs.Normal``.

    Returns:
        The ``zs.BayesianNet`` instance created by the ``with`` block.
    """
    relu = tf.nn.relu
    with zs.BayesianNet(observed=observed) as model:
        label = zs.Empirical('y', tf.int32, (n, y_dim))
        prior_mean = tf.zeros([n, z_dim])
        latent = zs.Normal(
            'z',
            prior_mean,
            std=1.,
            group_ndims=1,
            n_samples=n_particles,
        )
        # NOTE(review): only index 0 along the leading axis of 'z' is used
        # below; presumably that axis is the sample axis, in which case any
        # samples beyond the first are ignored -- confirm this is intended.
        first_latent = tf.to_float(latent[0])
        decoder_input = tf.concat([tf.to_float(label), first_latent], axis=1)

        hidden = tf.layers.dense(
            tf.to_float(decoder_input), 500, activation=relu)
        hidden = tf.layers.dense(hidden, 500, activation=relu)
        x_logits = tf.layers.dense(hidden, x_dim)

        # Both nodes are created without binding the result; the function
        # never reads them again and only returns the enclosing net.
        zs.Implicit('x_mean', tf.sigmoid(x_logits), group_ndims=1)
        zs.Bernoulli('x', logits=x_logits, group_ndims=1)
    return model