Exemplo n.º 1
0
def p_net(observed=None, n_z=None, is_initializing=False):
    """Build the generative network p(x, z) = p(z) p(x|z).

    z follows a Normal prior with zero mean and zero log-std (std 1) over
    ``config.z_dim`` dimensions. It is decoded by two 500-unit dense layers
    (leaky ReLU, act-norm, weight norm, L2 regularization) into the logits
    of a Bernoulli over ``config.x_dim`` dimensions.

    Args:
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_z: Number of z samples, forwarded as ``n_samples``.
        is_initializing: Forwarded as ``initializing`` to
            ``spt.layers.act_norm``.

    Returns:
        The ``spt.BayesianNet`` holding the nodes ``'z'`` and ``'x'``.
    """
    bayes_net = spt.BayesianNet(observed=observed)
    act_norm_fn = functools.partial(
        spt.layers.act_norm, initializing=is_initializing)

    # Prior p(z): factorized normal with mean 0 and log-std 0.
    prior_z = spt.Normal(mean=tf.zeros([1, config.z_dim]),
                         logstd=tf.zeros([1, config.z_dim]))
    z = bayes_net.add('z', prior_z, group_ndims=1, n_samples=n_z)

    # Decoder MLP: two identical 500-unit dense layers sharing these defaults.
    dense_defaults = dict(
        activation_fn=tf.nn.leaky_relu,
        normalizer_fn=act_norm_fn,
        weight_norm=True,
        kernel_regularizer=spt.layers.l2_regularizer(config.l2_reg),
    )
    with arg_scope([spt.layers.dense], **dense_defaults):
        features = z
        for _ in range(2):
            features = spt.layers.dense(features, 500)

    # Likelihood p(x|z): Bernoulli parameterized by logits. This layer is
    # outside the arg_scope, so it gets no activation or normalizer from it.
    x_logits = spt.layers.dense(features, config.x_dim, name='x_logits')
    bayes_net.add('x', spt.Bernoulli(logits=x_logits), group_ndims=1)

    return bayes_net
Exemplo n.º 2
0
def p_net(observed=None, n_y=None, n_z=None, n_samples=None):
    """Build the generative network p(x, y, z) = p(y) p(z|y) p(x|z).

    Args:
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_y: Number of y samples, forwarded as ``n_samples``.
        n_z: Number of z samples, forwarded as ``n_samples``.
        n_samples: Deprecated alias for ``n_y``. When not None it emits a
            warning and overrides ``n_y``, even if ``n_y`` was also given.

    Returns:
        The ``spt.BayesianNet`` holding the nodes ``'y'``, ``'z'`` and ``'x'``.
    """
    if n_samples is not None:
        # stacklevel=2 attributes the warning to the caller that passed the
        # deprecated argument, rather than to this line.
        warnings.warn('`n_samples` is deprecated, use `n_y` instead.',
                      stacklevel=2)
        n_y = n_samples

    net = spt.BayesianNet(observed=observed)

    # sample y ~ p(y): zero logits over config.n_clusters categories
    # (presumably a uniform prior; assumes the positional arg is `logits`).
    y = net.add('y',
                spt.Categorical(tf.zeros([1, config.n_clusters])),
                n_samples=n_y)

    # sample z ~ p(z|y) from the project's gaussian_mixture_prior; this node
    # is explicitly marked as not reparameterized.
    z = net.add('z',
                gaussian_mixture_prior(y, config.z_dim, config.n_clusters),
                group_ndims=1,
                n_samples=n_z,
                is_reparameterized=False)

    # compute the hidden features for x: two 500-unit leaky-ReLU dense
    # layers with L2 kernel regularization.
    with arg_scope([spt.layers.dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=spt.layers.l2_regularizer(
                       config.l2_reg)):
        h_x = z
        h_x = spt.layers.dense(h_x, 500)
        h_x = spt.layers.dense(h_x, 500)

    # sample x ~ p(x|z): Bernoulli over config.x_dim dims, parameterized by
    # logits from a dense layer outside the arg_scope.
    x_logits = spt.layers.dense(h_x, config.x_dim, name='x_logits')
    net.add('x', spt.Bernoulli(logits=x_logits), group_ndims=1)

    return net
Exemplo n.º 3
0
def p_net(observed=None, n_z=None):
    """Build the generative network p(x, z) with a binary latent z.

    Args:
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_z: Number of z samples, forwarded as ``n_samples``.

    Returns:
        The ``spt.BayesianNet`` holding the nodes ``'z'`` and ``'x'``.
    """
    net = spt.BayesianNet(observed=observed)

    # sample z ~ p(z): Bernoulli with zero parameters over config.z_dim dims
    # (presumably logits of 0, i.e. p = 0.5 -- confirm the positional arg).
    z = net.add('z', spt.Bernoulli(tf.zeros([1, config.z_dim])),
                group_ndims=1, n_samples=n_z)

    # compute the hidden features
    with arg_scope([spt.layers.dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=spt.layers.l2_regularizer(config.l2_reg)):
        # Bernoulli samples may be integer-typed; dense layers need floats.
        # tf.cast replaces the deprecated tf.to_float (removed in TF 2).
        h_z = tf.cast(z, tf.float32)
        h_z = spt.layers.dense(h_z, 500)
        h_z = spt.layers.dense(h_z, 500)

    # sample x ~ p(x|z): Bernoulli over config.x_dim dims via logits.
    x_logits = spt.layers.dense(h_z, config.x_dim, name='x_logits')
    net.add('x', spt.Bernoulli(logits=x_logits), group_ndims=1)

    return net
Exemplo n.º 4
0
def p_net(observed=None, n_z=None, is_training=False, is_initializing=False):
    """Build p(x, z) with a ResNet deconvolutional decoder producing an image.

    z has a Normal prior with zero mean and std ``config.truncated_sigma``.
    It is projected to a 64 x 7 x 7 feature map and upsampled by
    ``resnet_deconv2d_block`` layers. A 1x1 conv then yields single-channel
    Bernoulli logits.

    Args:
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_z: Number of z samples, forwarded as ``n_samples``.
        is_training: Accepted for interface compatibility; unused here.
        is_initializing: Forwarded as ``initializing`` to act-norm (only
            when ``config.act_norm`` is truthy).

    Returns:
        The ``spt.BayesianNet`` holding the nodes ``'z'`` and ``'x'``.
    """
    bayes_net = spt.BayesianNet(observed=observed)

    # Optional act-norm over the channel axis of 3-D feature maps.
    if config.act_norm:
        normalizer_fn = functools.partial(
            spt.layers.act_norm,
            axis=-1 if config.channels_last else -3,
            initializing=is_initializing,
            value_ndims=3,
        )
    else:
        normalizer_fn = None

    # Prior p(z): zero-mean normal with a constant std of truncated_sigma.
    prior_mean = tf.zeros([1, config.z_dim])
    prior_std = tf.ones([1, config.z_dim]) * config.truncated_sigma
    z = bayes_net.add('z',
                      spt.Normal(mean=prior_mean, std=prior_std),
                      group_ndims=1,
                      n_samples=n_z)

    # Layout of the initial 64-channel 7x7 feature map.
    if config.channels_last:
        seed_shape = (7, 7, 64)
    else:
        seed_shape = (64, 7, 7)

    # Decoder body. Only the deconv blocks are in the arg_scope; the initial
    # dense projection uses its own defaults.
    with arg_scope([spt.layers.resnet_deconv2d_block],
                   kernel_size=config.kernel_size,
                   shortcut_kernel_size=config.shortcut_kernel_size,
                   activation_fn=tf.nn.leaky_relu,
                   normalizer_fn=normalizer_fn,
                   kernel_regularizer=spt.layers.l2_regularizer(config.l2_reg),
                   channels_last=config.channels_last):
        hidden = spt.layers.dense(z, 64 * 7 * 7)
        hidden = spt.ops.reshape_tail(hidden, ndims=1, shape=seed_shape)
        # Shapes below are written channels-first: (channels, height, width).
        hidden = spt.layers.resnet_deconv2d_block(hidden, 64)             # (64, 7, 7)
        hidden = spt.layers.resnet_deconv2d_block(hidden, 32, strides=2)  # (32, 14, 14)
        hidden = spt.layers.resnet_deconv2d_block(hidden, 32)             # (32, 14, 14)
        hidden = spt.layers.resnet_deconv2d_block(hidden, 16, strides=2)  # (16, 28, 28)

    # Likelihood p(x|z): 1x1 conv to a single channel of logits, (1, 28, 28).
    x_logits = spt.layers.conv2d(hidden, 1, (1, 1),
                                 padding='same',
                                 name='feature_map_to_pixel',
                                 channels_last=config.channels_last)
    bayes_net.add('x',
                  spt.Bernoulli(logits=x_logits, dtype=tf.float32),
                  group_ndims=3)

    return bayes_net
Exemplo n.º 5
0
def q_net(x, observed=None, n_z=None):
    """Build the variational network q(z|x) with a binary latent z.

    Args:
        x: Input tensor; it is cast to float32 before the dense layers.
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_z: Number of z samples, forwarded as ``n_samples``.

    Returns:
        The ``spt.BayesianNet`` holding the node ``'z'``.
    """
    net = spt.BayesianNet(observed=observed)

    # compute the hidden features
    with arg_scope([spt.layers.dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=spt.layers.l2_regularizer(config.l2_reg)):
        # x may be integer-typed (e.g. binarized data); dense needs floats.
        # tf.cast replaces the deprecated tf.to_float (removed in TF 2).
        h_x = tf.cast(x, tf.float32)
        h_x = spt.layers.dense(h_x, 500)
        h_x = spt.layers.dense(h_x, 500)

    # sample z ~ q(z|x): Bernoulli over config.z_dim dims via logits.
    z_logits = spt.layers.dense(h_x, config.z_dim, name='z_logits')
    net.add('z', spt.Bernoulli(logits=z_logits), n_samples=n_z,
            group_ndims=1)

    return net
Exemplo n.º 6
0
def p_net(observed=None, n_z=None, is_initializing=False):
    """Build p(x, z) = p(z) p(x|z) with a learnable Gaussian-mixture prior.

    The prior on z is an ``spt.Mixture`` of ``config.n_mixture_components``
    Normals. Each component has trainable mean and log-std variables, with
    the log-std clamped from below at ``config.z_logstd_min``. The components
    are mixed by a categorical with zero logits (equal weights). z is decoded
    by two 500-unit dense layers into Bernoulli logits over ``config.x_dim``.

    Args:
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_z: Number of z samples, forwarded as ``n_samples``.
        is_initializing: Forwarded as ``initializing`` to
            ``spt.layers.act_norm``.

    Returns:
        The ``spt.BayesianNet`` holding the nodes ``'z'`` and ``'x'``.
    """
    bayes_net = spt.BayesianNet(observed=observed)
    act_norm_fn = functools.partial(spt.layers.act_norm,
                                    initializing=is_initializing)

    def build_component(idx):
        # Creates variables mean_<idx> then logstd_<idx>, in that order.
        mean_var = tf.get_variable('mean_{}'.format(idx),
                                   shape=[1, config.z_dim],
                                   dtype=tf.float32,
                                   trainable=True)
        raw_logstd = tf.get_variable('logstd_{}'.format(idx),
                                     shape=[1, config.z_dim],
                                     dtype=tf.float32,
                                     trainable=True)
        # Lower bound on log-std (presumably to stop a component's variance
        # collapsing).
        clamped_logstd = tf.maximum(raw_logstd, config.z_logstd_min)
        component = spt.Normal(mean=mean_var, logstd=clamped_logstd)
        # Presumably moves the trailing z_dim axis into the value (event)
        # shape -- confirm against the spt docs.
        return component.expand_value_ndims(1)

    components = list(map(build_component,
                          range(config.n_mixture_components)))
    weights = spt.Categorical(
        logits=tf.zeros([1, config.n_mixture_components]))
    prior_z = spt.Mixture(categorical=weights,
                          components=components,
                          is_reparameterized=True)
    z = bayes_net.add('z', prior_z, n_samples=n_z)

    # Decoder MLP: two 500-unit dense layers sharing these defaults.
    dense_defaults = dict(
        activation_fn=tf.nn.leaky_relu,
        normalizer_fn=act_norm_fn,
        weight_norm=True,
        kernel_regularizer=spt.layers.l2_regularizer(config.l2_reg),
    )
    with arg_scope([spt.layers.dense], **dense_defaults):
        features = z
        for _ in range(2):
            features = spt.layers.dense(features, 500)

    # Likelihood p(x|z): Bernoulli over config.x_dim dims via logits.
    x_logits = spt.layers.dense(features, config.x_dim, name='x_logits')
    bayes_net.add('x', spt.Bernoulli(logits=x_logits), group_ndims=1)

    return bayes_net
Exemplo n.º 7
0
def p_net(observed=None, n_z=None, is_training=True, channels_last=True):
    """Build p(x, z) with a ResNet deconvolutional decoder.

    z has a Normal prior (mean 0, log-std 0) over ``config.z_dim`` dims. It
    is projected to a 64 x 7 x 7 feature map and upsampled by
    ``resnet_deconv2d_block`` layers. A 1x1 conv then produces one channel of
    logits, which is reshaped into a flat vector of ``config.x_dim`` logits.

    Args:
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_z: Number of z samples, forwarded as ``n_samples``.
        is_training: Accepted for interface compatibility; unused here.
        channels_last: Feature-map layout passed to the deconv blocks and
            the output conv, and used to pick the reshape target.

    Returns:
        The ``spt.BayesianNet`` holding the nodes ``'z'`` and ``'x'``.
    """
    bayes_net = spt.BayesianNet(observed=observed)

    # Prior p(z): factorized normal with mean 0 and log-std 0.
    prior_z = spt.Normal(mean=tf.zeros([1, config.z_dim]),
                         logstd=tf.zeros([1, config.z_dim]))
    z = bayes_net.add('z', prior_z, group_ndims=1, n_samples=n_z)

    # Layout of the initial 64-channel 7x7 feature map.
    if channels_last:
        seed_shape = [7, 7, 64]
    else:
        seed_shape = [64, 7, 7]

    # Decoder body. Only the deconv blocks are in the arg_scope; the initial
    # dense projection uses its own defaults.
    with arg_scope([spt.layers.resnet_deconv2d_block],
                   kernel_size=config.kernel_size,
                   shortcut_kernel_size=config.shortcut_kernel_size,
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=spt.layers.l2_regularizer(config.l2_reg),
                   channels_last=channels_last):
        hidden = spt.layers.dense(z, 64 * 7 * 7)
        hidden = spt.utils.reshape_tail(hidden, ndims=1, shape=seed_shape)
        # Shapes below are written channels-first: (channels, height, width).
        hidden = spt.layers.resnet_deconv2d_block(hidden, 64)             # (64, 7, 7)
        hidden = spt.layers.resnet_deconv2d_block(hidden, 32, strides=2)  # (32, 14, 14)
        hidden = spt.layers.resnet_deconv2d_block(hidden, 32)             # (32, 14, 14)
        hidden = spt.layers.resnet_deconv2d_block(hidden, 16, strides=2)  # (16, 28, 28)

    # Likelihood p(x|z): 1x1 conv to a (1, 28, 28) logit map, then collapse
    # its trailing 3 dims into a single config.x_dim axis.
    pixel_logits = spt.layers.conv2d(hidden, 1, (1, 1),
                                     padding='same',
                                     name='feature_map_to_pixel',
                                     channels_last=channels_last)
    x_logits = spt.utils.reshape_tail(pixel_logits, 3, [config.x_dim])
    bayes_net.add('x', spt.Bernoulli(logits=x_logits), group_ndims=1)

    return bayes_net
Exemplo n.º 8
0
def p_net(observed=None, n_z=None, beta=1.0, mcmc_iterator=0, log_Z=0.0,
          initial_z=None, mcmc_alpha=None):
    """Build p(x, z) with an energy-based prior on z.

    The prior on z is an ``EnergyDistribution``. It combines a normal base
    distribution (mean 0, log-std 0) with the project's ``G_theta`` and
    ``D_psi`` networks and a trainable scalar ``xi`` squashed into (0, 1).
    x is Bernoulli with mean ``G_theta(z)``.

    Args:
        observed: Observed tensors forwarded to ``spt.BayesianNet``.
        n_z: Number of z samples, forwarded as ``n_samples``.
        beta: Accepted for interface compatibility; unused in this body.
        mcmc_iterator: Forwarded unchanged to ``EnergyDistribution``.
        log_Z: Forwarded unchanged to ``EnergyDistribution``.
        initial_z: Forwarded unchanged to ``EnergyDistribution``.
        mcmc_alpha: Forwarded to ``EnergyDistribution``. When None (the
            default), ``config.smallest_step`` is read at call time.

    Returns:
        The ``spt.BayesianNet`` holding the nodes ``'z'`` and ``'x'``.
    """
    if mcmc_alpha is None:
        # Resolved here rather than in the signature: a default of
        # ``config.smallest_step`` there is evaluated once, at definition
        # time, and would miss later updates to ``config`` (presumably made
        # by CLI parsing after import -- confirm).
        mcmc_alpha = config.smallest_step

    net = spt.BayesianNet(observed=observed)

    # Base distribution for z. batch_ndims_to_value(1) presumably moves the
    # trailing z_dim axis into the value (event) shape -- confirm in spt docs.
    normal = spt.Normal(mean=tf.zeros([1, config.z_dim]),
                        logstd=tf.zeros([1, config.z_dim]))
    normal = normal.batch_ndims_to_value(1)

    # Trainable scalar xi. The sigmoid maps it into (0, 1), so
    # config.initial_xi is the pre-sigmoid initial value.
    xi = tf.get_variable(name='xi', shape=(),
                         initializer=tf.constant_initializer(config.initial_xi),
                         dtype=tf.float32, trainable=True)
    # TODO(review): the original flagged this line with a bare "TODO" beside
    # a commented-out tf.square alternative; the squashing choice may be
    # provisional.
    xi = tf.nn.sigmoid(xi)

    pz = EnergyDistribution(normal, G=G_theta, D=D_psi, log_Z=log_Z, xi=xi,
                            mcmc_iterator=mcmc_iterator, initial_z=initial_z,
                            mcmc_alpha=mcmc_alpha)
    z = net.add('z', pz, n_samples=n_z)

    # Decode z to Bernoulli means. Clip away from {0, 1} so the logit
    # log(p) - log(1 - p) stays finite.
    x_mean = G_theta(z)
    x_mean = tf.clip_by_value(x_mean, 1e-7, 1 - 1e-7)
    logits = tf.log(x_mean) - tf.log1p(-x_mean)
    x_dist = spt.Bernoulli(logits=logits, dtype=tf.float32)
    # group_ndims=3: presumably x is an image with 3 event dims -- confirm.
    net.add('x', x_dist, group_ndims=3)

    return net