예제 #1
0
def px_graph(z, y):
    """Build (or reuse) the decoder graph and return its output logits.

    A mean ``zm`` and a softplus-activated variance ``zv`` (64 units each) are
    predicted from ``y``; ``z`` is shifted/scaled by them
    (``zy = zm + sqrt(zv) * z``), passed through ``custom_layer`` and three
    512-unit ReLU ``Dense`` layers, and finally through a 784-unit ``Dense``
    layer with no activation argument.

    The first call creates the variables; any later call (detected by the
    presence of variables under the ``px`` scope) reuses them.

    Args:
        z: Tensor that must broadcast against the 64-unit ``zm``/``zv``
            tensors (assumed to be (batch, 64) noise — TODO confirm with caller).
        y: Tensor fed to the ``zm``/``zv`` Dense layers (presumably the
            mixture-component assignment — verify against caller).

    Returns:
        ``px_logit``: the output of the 784-unit ``logit`` Dense layer.
    """
    # Variables under 'px' only exist after the first call, so their presence
    # means this call must reuse them instead of creating new ones.
    reuse = len(tf.get_collection(tf.GraphKeys.VARIABLES, scope='px')) > 0
    # -- transform z to be a sample from one of the Gaussian mixture components
    with tf.variable_scope('z_transform'):
        zm = Dense(y, 64, 'zm', reuse=reuse)
        # softplus (presumably applied as the Dense activation) keeps zv
        # positive so the tf.sqrt below is well defined.
        zv = Dense(y, 64, 'zv', tf.nn.softplus, reuse=reuse)
    # -- p(x)
    with tf.variable_scope('px'):
        with tf.name_scope('layer1'):
            # Shift and scale z by the y-dependent mean and std-dev.
            zy = zm + tf.sqrt(zv) * z
            h1 = custom_layer(zy, reuse)
        h2 = Dense(h1, 512, 'layer2', tf.nn.relu, reuse=reuse)
        h3 = Dense(h2, 512, 'layer3', tf.nn.relu, reuse=reuse)
        h4 = Dense(h3, 512, 'layer4', tf.nn.relu, reuse=reuse)
        # Read from the last hidden layer so layer3/layer4 actually feed the
        # output. h4 has the same width (512) as h2, so the 'logit' variables
        # keep their shape.
        px_logit = Dense(h4, 784, 'logit', reuse=reuse)
    return px_logit
예제 #2
0
def px_graph(z, y):
    """Build (or reuse) the prior network on ``y`` and the decoder on ``z``.

    Prior branch (scope ``pz``): ``y`` goes through two 128-unit ReLU
    ``Dense`` layers, then to a 64-unit mean ``zm`` and a softplus-activated
    64-unit variance ``zv``.

    Decoder branch (scope ``px``): ``z`` goes through a 512-unit ReLU layer
    and a 28*14*14-unit ReLU layer, is reshaped to a 14x14x28 feature map,
    then through three 3x3 transposed convolutions (the last with stride 2 and
    "same" padding, upsampling to 28x28), and a 1-filter 3x3 ``Conv2d``.
    That output is flattened into ``px_logit`` (784 values per example if
    ``Conv2d`` preserves the 28x28 size — TODO confirm its padding).

    The first call creates the variables; later calls (detected by the
    presence of variables under the ``px`` scope) reuse them.

    Args:
        z: Tensor fed to the first decoder ``Dense`` layer.
        y: Tensor fed to the prior network.

    Returns:
        Tuple ``(zm, zv, px_logit)``.
    """
    # Variables under 'px' only exist after the first call, so their presence
    # means this call must reuse them instead of creating new ones.
    reuse = len(tf.get_collection(tf.GraphKeys.VARIABLES, scope='px')) > 0
    # -- p(z)
    with tf.variable_scope('pz'):
        h1 = Dense(y, 128, 'h1', tf.nn.relu, reuse=reuse)
        h2 = Dense(h1, 128, 'h2', tf.nn.relu, reuse=reuse)
        zm = Dense(h2, 64, 'zm', reuse=reuse)
        zv = Dense(h2, 64, 'zv', tf.nn.softplus, reuse=reuse)
    # -- p(x)
    with tf.variable_scope('px'):
        h1 = Dense(z, 512, 'layer1', tf.nn.relu, reuse=reuse)
        h2 = Dense(h1, 28 * 14 * 14, 'layer2', tf.nn.relu, reuse=reuse)
        h2 = tf.reshape(h2, [-1, 14, 14, 28])
        # tf.layers resolves reuse=True by layer name. With unnamed layers a
        # second call cannot map each layer back to its own variables
        # (depending on the TF version it either fails or resolves all three to
        # 'conv2d_transpose' and shares conv1's weights). These explicit names
        # match what TF auto-generates for unnamed layers on the first call, so
        # existing checkpoints keep loading.
        conv1 = tf.layers.conv2d_transpose(h2,
                                           28, [3, 3], (1, 1),
                                           padding="same",
                                           activation=tf.nn.relu,
                                           name='conv2d_transpose',
                                           reuse=reuse)
        conv2 = tf.layers.conv2d_transpose(conv1,
                                           28, [3, 3], (1, 1),
                                           padding="same",
                                           activation=tf.nn.relu,
                                           name='conv2d_transpose_1',
                                           reuse=reuse)
        # Stride 2 with "same" padding doubles the spatial size: 14x14 -> 28x28.
        conv3 = tf.layers.conv2d_transpose(conv2,
                                           28, [3, 3], (2, 2),
                                           padding="same",
                                           activation=tf.nn.relu,
                                           name='conv2d_transpose_2',
                                           reuse=reuse)
        # The output is used as logits, which must be able to go negative: a
        # ReLU here would clamp them to >= 0 (sigmoid >= 0.5 for every pixel).
        # NOTE(review): assumes the project's Conv2d helper treats
        # activation=None as a linear output — TODO confirm.
        conv4 = Conv2d(conv3,
                       1, [3, 3], [1, 1],
                       activation=None,
                       reuse=reuse,
                       scope='convlayer3')
        px_logit = tf.contrib.layers.flatten(conv4)
    return zm, zv, px_logit