def _transform_layer(self, layer_id, x, compute_y, compute_log_det):
    w, u, b, u_hat = \
        self.get_layer_params(layer_id, ['w', 'u', 'b', 'u_hat'])

    # flatten x for better performance
    x, s1, s2 = flatten(x, 2)  # x.shape == [?, n_units]
    wxb = tf.matmul(x, w, transpose_b=True) + b  # shape == [?, 1]
    tanh_wxb = tf.tanh(wxb)  # shape == [?, 1]

    # compute y = f(x)
    y = None
    if compute_y:
        y = x + u_hat * tanh_wxb  # shape == [?, n_units]
        y = unflatten(y, s1, s2)

    # compute log(det|df/dx|)
    log_det = None
    if compute_log_det:
        grad = 1. - tf.square(tanh_wxb)  # dtanh(x)/dx = 1 - tanh^2(x)
        phi = grad * w  # shape == [?, n_units]
        u_phi = tf.matmul(phi, u_hat, transpose_b=True)  # shape == [?, 1]
        det_jac = 1. + u_phi  # shape == [?, 1]
        log_det = tf.log(tf.abs(det_jac))  # shape == [?, 1]
        log_det = unflatten(tf.squeeze(log_det, -1), s1, s2)

    # return the transformed sample and the log-determinant
    return y, log_det
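# How `u_hat` is produced is outside this listing (`get_layer_params` is not
# shown).  As a hedged sketch, planar flows (Rezende & Mohamed, 2015) commonly
# derive `u_hat` from an unconstrained `u` so that w . u_hat >= -1, a
# sufficient condition for det(df/dx) = 1 + u_hat^T phi(x) to stay non-negative
# and the flow to remain invertible.  `_make_u_hat` below is a hypothetical
# helper, not part of the original code; `w` and `u` are assumed to have
# shape [1, n_units]:
def _make_u_hat(w, u):
    wu = tf.matmul(w, u, transpose_b=True)            # w . u, shape == [1, 1]
    m_wu = -1. + tf.nn.softplus(wu)                   # m(a) = -1 + log(1 + e^a) > -1
    w_norm_sq = tf.reduce_sum(tf.square(w), axis=-1, keepdims=True)
    return u + (m_wu - wu) * (w / w_norm_sq)          # shape == [1, n_units]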
def q_net(config, x, observed=None, n_samples=None, is_training=True):
    net = BayesianNet(observed=observed)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x = tf.to_float(x)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample y ~ q(y|x)
    y_logits = dense(h_x, config.n_clusters, name='y_logits')
    y = net.add('y', Categorical(y_logits), n_samples=n_samples)
    y_one_hot = tf.one_hot(y, config.n_clusters, dtype=tf.float32)

    # sample z ~ q(z|y,x)
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        if config.mean_field_assumption_for_q:
            # under the mean-field assumption we let q(z|y,x) = q(z|x)
            h_z, s1, s2 = flatten(h_x, 2)
            z_n_samples = n_samples
        else:
            if n_samples is not None:
                # tile h_x along the sample axis, so that it can be
                # concatenated with the per-sample one-hot codes of y
                h_z = tf.concat(
                    [tf.tile(tf.reshape(h_x, [1, -1, 500]),
                             tf.stack([n_samples, 1, 1])),
                     y_one_hot],
                    axis=-1
                )
            else:
                h_z = tf.concat([h_x, y_one_hot], axis=-1)
            h_z, s1, s2 = flatten(h_z, 2)
            h_z = dense(h_z, 500)
            z_n_samples = None

    z_mean = dense(h_z, config.z_dim, name='z_mean')
    z_logstd = dense(h_z, config.z_dim, name='z_logstd')
    z = net.add('z',
                Normal(mean=unflatten(z_mean, s1, s2),
                       logstd=unflatten(z_logstd, s1, s2),
                       is_reparameterized=False),
                n_samples=z_n_samples,
                group_ndims=1)

    return net
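# Shape sketch for the non-mean-field branch of `q_net` above, with S =
# n_samples and batch size B (the trace is illustrative, not from the
# original code):
#   h_x                                         : [B, 500]
#   y_one_hot                                   : [S, B, n_clusters]
#   tf.reshape(h_x, [1, -1, 500])               : [1, B, 500]
#   tf.tile(..., tf.stack([n_samples, 1, 1]))   : [S, B, 500]
#   tf.concat([..., y_one_hot], axis=-1)        : [S, B, 500 + n_clusters]
#   flatten(h_z, 2)                             : [S * B, 500 + n_clusters]
# i.e. the shared features are replicated once per y-sample, then folded into
# the batch axis for the dense layers.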
def reinforce_baseline_net(config, x):
    x, s1, s2 = flatten(tf.to_float(x), 2)
    with arg_scope([dense],
                   kernel_regularizer=l2_regularizer(config.l2_reg),
                   activation_fn=tf.nn.leaky_relu):
        h_x = dense(x, 500)
    h_x = unflatten(tf.reshape(dense(h_x, 1), [-1]), s1, s2)
    return h_x
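# Sketch: how such a baseline typically enters a REINFORCE (score-function)
# estimator, where grad E_q[f(y)] = E_q[(f(y) - b(x)) * grad log q(y|x)] and
# b(x) is regressed onto f to reduce variance.  `log_joint` (the Monte-Carlo
# objective f) and `log_q_y` (log q(y|x) of the sampled y) are assumed names,
# not defined in this listing:
baseline = reinforce_baseline_net(config, x)            # b(x)
advantage = tf.stop_gradient(log_joint - baseline)      # (f - b), treated as a constant
surrogate = advantage * log_q_y                         # gradients flow through log q(y|x)
baseline_cost = tf.square(
    tf.stop_gradient(log_joint) - baseline)             # regress b(x) onto f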
def p_net(config, observed=None, n_z=None, is_training=True,
          channels_last=False):
    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Normal(mean=tf.zeros([1, config.z_dim]),
                       logstd=tf.zeros([1, config.z_dim])),
                group_ndims=1,
                n_samples=n_z)

    # compute the hidden features
    with arg_scope([deconv_resnet_block],
                   shortcut_kernel_size=config.shortcut_kernel_size,
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg),
                   channels_last=channels_last):
        h_z, s1, s2 = flatten(z, 2)
        h_z = tf.reshape(dense(h_z, 64 * 7 * 7),
                         [-1, 7, 7, 64] if channels_last else [-1, 64, 7, 7])
        h_z = deconv_resnet_block(h_z, 64)  # output: (64, 7, 7)
        h_z = deconv_resnet_block(h_z, 32, strides=2)  # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 32)  # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 16, strides=2)  # output: (16, 28, 28)
    h_z = conv2d(h_z, 1, (1, 1), padding='same',
                 name='feature_map_to_pixel',
                 channels_last=channels_last)  # output: (1, 28, 28)
    h_z = tf.reshape(h_z, [-1, config.x_dim])

    # sample x ~ p(x|z)
    x_logits = unflatten(h_z, s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
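# Sketch: drawing fresh samples from the prior with this decoder.  Dict-style
# access `net['x']` to the sampled stochastic tensor is an assumed convention
# here (sampled tensors otherwise behave like ordinary tensors, as in the
# code above); the 28x28 geometry follows the shape comments in `p_net`:
p = p_net(config, n_z=16, channels_last=True)
x_samples = tf.cast(p['x'], tf.float32)            # shape == [16, 1, config.x_dim]
images = tf.reshape(x_samples, [-1, 28, 28, 1])    # 16 binary 28x28 images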
def p_net(config, observed=None, n_z=None, is_training=True):
    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Normal(mean=tf.zeros([1, config.z_dim]),
                       logstd=tf.zeros([1, config.z_dim])),
                group_ndims=1,
                n_samples=n_z)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_z, s1, s2 = flatten(z, 2)
        h_z = dense(h_z, 500)
        h_z = dense(h_z, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_z, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
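# Sketch: a single-sample Monte-Carlo ELBO wired by hand from a matching
# encoder.  `q_net` here stands for a plain encoder over a single latent 'z'
# (not the clustering `q_net` above), and both `local_log_prob(name)` and
# dict-style access on `BayesianNet` are assumed conventions:
q = q_net(config, x)
p = p_net(config, observed={'x': x, 'z': q['z']})
elbo = (p.local_log_prob('x') + p.local_log_prob('z')
        - q.local_log_prob('z'))       # log p(x|z) + log p(z) - log q(z|x)
loss = -tf.reduce_mean(elbo)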
def p_net(config, observed=None, n_y=None, n_z=None, is_training=True,
          n_samples=None):
    if n_samples is not None:
        warnings.warn('`n_samples` is deprecated, use `n_y` instead.')
        n_y = n_samples

    net = BayesianNet(observed=observed)

    # sample y ~ p(y)
    y = net.add('y', Categorical(tf.zeros([1, config.n_clusters])),
                n_samples=n_y)

    # sample z ~ p(z|y)
    z = net.add('z',
                gaussian_mixture_prior(y, config.z_dim, config.n_clusters),
                group_ndims=1,
                n_samples=n_z,
                is_reparameterized=False)

    # compute the hidden features for x
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x, s1, s2 = flatten(z, 2)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_x, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
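# `gaussian_mixture_prior` is referenced above but not defined in this listing.
# A minimal sketch of one common construction: each of the n_clusters
# components owns a trainable (mean, logstd) pair, and the sampled `y` selects
# the component, so p(z|y) = N(mean_y, exp(logstd_y)^2).  The variable names
# below are illustrative, not from the original code:
def gaussian_mixture_prior(y, z_dim, n_clusters):
    prior_mean = tf.get_variable(
        'z_prior_mean', dtype=tf.float32, shape=[n_clusters, z_dim],
        initializer=tf.random_normal_initializer())
    prior_logstd = tf.get_variable(
        'z_prior_logstd', dtype=tf.float32, shape=[n_clusters, z_dim],
        initializer=tf.zeros_initializer())
    z_mean = tf.nn.embedding_lookup(prior_mean, y)       # mean of component y
    z_logstd = tf.nn.embedding_lookup(prior_logstd, y)   # logstd of component y
    return Normal(mean=z_mean, logstd=z_logstd)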