Example #1
    def _transform_layer(self, layer_id, x, compute_y, compute_log_det):
        w, u, b, u_hat = \
            self.get_layer_params(layer_id, ['w', 'u', 'b', 'u_hat'])

        # flatten x for better performance
        x, s1, s2 = flatten(x, 2)  # x.shape == [?, n_units]
        wxb = tf.matmul(x, w, transpose_b=True) + b  # shape == [?, 1]
        tanh_wxb = tf.tanh(wxb)  # shape == [?, 1]

        # compute y = f(x)
        y = None
        if compute_y:
            y = x + u_hat * tanh_wxb  # shape == [?, n_units]
            y = unflatten(y, s1, s2)

        # compute log|det(df/dx)|
        log_det = None
        if compute_log_det:
            grad = 1. - tf.square(tanh_wxb)  # dtanh(x)/dx = 1 - tanh^2(x)
            phi = grad * w  # shape == [?, n_units]
            u_phi = tf.matmul(phi, u_hat, transpose_b=True)  # shape == [?, 1]
            det_jac = 1. + u_phi  # shape == [?, 1]
            log_det = tf.log(tf.abs(det_jac))  # shape == [?, 1]
            log_det = unflatten(tf.squeeze(log_det, -1), s1, s2)

        # return the transformed sample and the log-determinant
        return y, log_det
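
For reference, the transform above is the planar normalizing flow of Rezende & Mohamed (2015); in the notation of the code (u_hat is the constrained version of u that keeps the flow invertible):

    f(x) = x + \hat{u} \tanh(w^\top x + b)
    \psi(x) = \left(1 - \tanh^2(w^\top x + b)\right) w
    \log\left|\det \frac{\partial f}{\partial x}\right| = \log\left|1 + \hat{u}^\top \psi(x)\right|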
Example #2
def q_net(config, x, observed=None, n_samples=None, is_training=True):
    net = BayesianNet(observed=observed)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x = tf.to_float(x)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample y ~ q(y|x)
    y_logits = dense(h_x, config.n_clusters, name='y_logits')
    y = net.add('y', Categorical(y_logits), n_samples=n_samples)
    y_one_hot = tf.one_hot(y, config.n_clusters, dtype=tf.float32)

    # sample z ~ q(z|y,x)
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        if config.mean_field_assumption_for_q:
            # under the mean-field assumption we let q(z|y,x) = q(z|x)
            h_z, s1, s2 = flatten(h_x, 2)
            z_n_samples = n_samples
        else:
            if n_samples is not None:
                # tile h_x over the sample axis so it aligns with y_one_hot
                h_z = tf.concat(
                    [tf.tile(tf.reshape(h_x, [1, -1, 500]),
                             tf.stack([n_samples, 1, 1])),
                     y_one_hot],
                    axis=-1)
            else:
                h_z = tf.concat([h_x, y_one_hot], axis=-1)
            h_z, s1, s2 = flatten(h_z, 2)
            h_z = dense(h_z, 500)
            z_n_samples = None

    z_mean = dense(h_z, config.z_dim, name='z_mean')
    z_logstd = dense(h_z, config.z_dim, name='z_logstd')
    z = net.add('z',
                Normal(mean=unflatten(z_mean, s1, s2),
                       logstd=unflatten(z_logstd, s1, s2),
                       is_reparameterized=False),
                n_samples=z_n_samples,
                group_ndims=1)

    return net
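
A minimal usage sketch for q_net, assuming a tfsnippet-style BayesianNet that exposes its stochastic tensors by name and supports local_log_prob; the variable names are illustrative, not part of the listing:

# hypothetical wiring: 10 posterior samples of y (and z) per input x
q = q_net(config, x, n_samples=10)
y_samples = q['y']              # Categorical, shape [10, batch_size]
z_samples = q['z']              # Normal, shape [10, batch_size, config.z_dim]
log_qz = q.local_log_prob('z')  # log q(z|y,x), with z_dim reduced by group_ndims=1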
Example #3
def reinforce_baseline_net(config, x):
    x, s1, s2 = flatten(tf.to_float(x), 2)
    with arg_scope([dense],
                   kernel_regularizer=l2_regularizer(config.l2_reg),
                   activation_fn=tf.nn.leaky_relu):
        h_x = dense(x, 500)
    h_x = unflatten(tf.reshape(dense(h_x, 1), [-1]), s1, s2)
    return h_x
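
reinforce_baseline_net produces a per-example scalar b(x), used as a control variate in the REINFORCE (score-function) gradient estimator; subtracting it reduces variance without adding bias, since E[grad log q] = 0:

    \nabla_\phi \mathbb{E}_{q_\phi(y|x)}[f(y)] = \mathbb{E}_{q_\phi(y|x)}\left[\left(f(y) - b(x)\right) \nabla_\phi \log q_\phi(y|x)\right]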
Example #4
def p_net(config,
          observed=None,
          n_z=None,
          is_training=True,
          channels_last=False):
    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Normal(mean=tf.zeros([1, config.z_dim]),
                       logstd=tf.zeros([1, config.z_dim])),
                group_ndims=1,
                n_samples=n_z)

    # compute the hidden features
    with arg_scope([deconv_resnet_block],
                   shortcut_kernel_size=config.shortcut_kernel_size,
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg),
                   channels_last=channels_last):
        h_z, s1, s2 = flatten(z, 2)
        h_z = tf.reshape(dense(h_z, 64 * 7 * 7),
                         [-1, 7, 7, 64] if channels_last else [-1, 64, 7, 7])
        h_z = deconv_resnet_block(h_z, 64)  # output: (64, 7, 7)
        h_z = deconv_resnet_block(h_z, 32, strides=2)  # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 32)  # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 16, strides=2)  # output: (16, 28, 28)
    h_z = conv2d(h_z, 1, (1, 1),
                 padding='same',
                 name='feature_map_to_pixel',
                 channels_last=channels_last)  # output: (1, 28, 28)
    h_z = tf.reshape(h_z, [-1, config.x_dim])

    # sample x ~ p(x|z)
    x_logits = unflatten(h_z, s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
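
A hedged sampling sketch for this decoder, assuming the returned BayesianNet exposes its 'x' node as a tensor-convertible stochastic tensor (tfsnippet-style); the session boilerplate is illustrative TF1 code:

# hypothetical: draw 16 prior samples of z and decode them to images
sample_net = p_net(config, n_z=16, channels_last=True)
x_samples = tf.cast(sample_net['x'], tf.float32)  # shape [16, 1, config.x_dim]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    images = sess.run(x_samples).reshape([-1, 28, 28])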
Example #5
def p_net(config, observed=None, n_z=None, is_training=True):
    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Normal(mean=tf.zeros([1, config.z_dim]),
                       logstd=tf.zeros([1, config.z_dim])),
                group_ndims=1,
                n_samples=n_z)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_z, s1, s2 = flatten(z, 2)
        h_z = dense(h_z, 500)
        h_z = dense(h_z, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_z, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
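
A sketch of how this decoder pairs with an encoder for a single-sample ELBO; q_net here stands for any matching encoder that returns a BayesianNet with a 'z' node, and local_log_prob is assumed from the tfsnippet-style API:

# hypothetical ELBO wiring (names illustrative)
q = q_net(config, x)
p = p_net(config, observed={'x': x, 'z': q['z']})
# ELBO = E_q[log p(x|z) + log p(z) - log q(z|x)]
elbo = tf.reduce_mean(
    p.local_log_prob('x') + p.local_log_prob('z') - q.local_log_prob('z'))
loss = -elbo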
Example #6
def p_net(config,
          observed=None,
          n_y=None,
          n_z=None,
          is_training=True,
          n_samples=None):
    if n_samples is not None:
        warnings.warn('`n_samples` is deprecated, use `n_y` instead.')
        n_y = n_samples

    net = BayesianNet(observed=observed)

    # sample y ~ p(y)
    y = net.add('y',
                Categorical(tf.zeros([1, config.n_clusters])),
                n_samples=n_y)

    # sample z ~ p(z|y)
    z = net.add('z',
                gaussian_mixture_prior(y, config.z_dim, config.n_clusters),
                group_ndims=1,
                n_samples=n_z,
                is_reparameterized=False)

    # compute the hidden features for x
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x, s1, s2 = flatten(z, 2)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_x, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
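
gaussian_mixture_prior is not part of this listing. A minimal sketch of one plausible implementation, with hypothetical variable names: one trainable Gaussian component per cluster, indexed by the sampled y:

def gaussian_mixture_prior(y, z_dim, n_clusters):
    # trainable per-cluster component parameters (illustrative initialization)
    prior_mean = tf.get_variable(
        'z_prior_mean', dtype=tf.float32, shape=[n_clusters, z_dim],
        initializer=tf.random_normal_initializer())
    prior_logstd = tf.get_variable(
        'z_prior_logstd', dtype=tf.float32, shape=[n_clusters, z_dim],
        initializer=tf.zeros_initializer())
    # look up the component parameters for each sampled y
    z_mean = tf.nn.embedding_lookup(prior_mean, y)
    z_logstd = tf.nn.embedding_lookup(prior_logstd, y)
    return Normal(mean=z_mean, logstd=z_logstd)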