def conditioning_network_28(x,
                            masks,
                            nr_filters,
                            is_training=True,
                            nonlinearity=None,
                            bn=True,
                            kernel_initializer=None,
                            kernel_regularizer=None,
                            counters={}):
    """Build a stack of stride-1 "SAME" convolutions over a masked input.

    The input is multiplied by the mask (broadcast to 3 channels), then the
    mask itself (broadcast to 1 channel) and a constant plane of ones are
    appended along the last axis before the convolutions are applied.

    Args:
        x: input tensor. The second concat uses axis 3, so a rank-4,
            channels-last tensor with 3 channels is presumably expected
            (TODO confirm against callers).
        masks: mask passed to ``broadcast_masks_tf``.
        nr_filters: number of filters used by every convolution.
        is_training: forwarded to the layer helpers through ``arg_scope``.
        nonlinearity: default activation for the layers; the final 1x1
            convolution overrides it with ``None``.
        bn: default batch-norm flag for the layers; the final 1x1
            convolution overrides it with ``False``.
        kernel_initializer: forwarded to the layer helpers.
        kernel_regularizer: forwarded to the layer helpers.
        counters: dict handed to ``get_name`` and to the layer helpers.
            The default is a single dict object shared by every call that
            omits this argument.

    Returns:
        The output tensor of the final 1x1 convolution.
    """
    scope_name = get_name("conditioning_network_28", counters)

    # Zero out the masked region, then expose the mask as an extra channel.
    masked = x * broadcast_masks_tf(masks, num_channels=3)
    mask_plane = broadcast_masks_tf(masks, num_channels=1)
    with_mask = tf.concat([masked, mask_plane], axis=-1)

    # Append one more channel filled with ones.
    ones_shape = int_shape(with_mask)[:-1] + [1]
    net_input = tf.concat([with_mask, tf.ones(ones_shape)], 3)

    layer_defaults = dict(
        nonlinearity=nonlinearity,
        bn=bn,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        is_training=is_training,
        counters=counters,
    )
    with tf.variable_scope(scope_name):
        with arg_scope([conv2d, residual_block, dense], **layer_defaults):
            hidden = net_input
            # Five identical convolutions (one on the input, four more after).
            for _ in range(5):
                hidden = conv2d(hidden, nr_filters, 4, 1, "SAME")
            # Final 1x1 projection without activation or batch norm.
            hidden = conv2d(hidden, nr_filters, 1, 1, "SAME",
                            nonlinearity=None, bn=False)
    return hidden
def conv_encoder_32(inputs,
                    z_dim,
                    nonlinearity=None,
                    bn=True,
                    kernel_initializer=None,
                    kernel_regularizer=None,
                    is_training=False,
                    counters={}):
    """Build a convolutional encoder that emits two ``z_dim``-sized heads.

    Four convolutions are applied to ``inputs``; the result is reshaped to
    rows of 256 values and fed to two separate dense layers.

    Args:
        inputs: input tensor for the first convolution.
        z_dim: output size of each dense head.
        nonlinearity: default activation for the layers; both dense heads
            override it with ``None``.
        bn: default batch-norm flag for the layers; both dense heads
            override it with ``False``.
        kernel_initializer: forwarded to the layer helpers.
        kernel_regularizer: forwarded to the layer helpers.
        is_training: forwarded to the layer helpers (defaults to ``False``).
        counters: dict handed to ``get_name`` and to the layer helpers.
            The default is a single dict object shared by every call that
            omits this argument.

    Returns:
        A tuple ``(z_mu, z_log_sigma_sq)`` of the two dense-layer outputs.
    """
    scope_name = get_name("conv_encoder_32", counters)
    print("construct", scope_name, "...")

    # Positional arguments given to conv2d after the input tensor, per layer.
    conv_specs = (
        (64, 4, 2, "SAME"),
        (128, 4, 2, "SAME"),
        (128, 4, 2, "SAME"),
        (256, 4, 1, "VALID"),
    )
    layer_defaults = dict(
        nonlinearity=nonlinearity,
        bn=bn,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        is_training=is_training,
        counters=counters,
    )

    with tf.variable_scope(scope_name):
        with arg_scope([conv2d, dense], **layer_defaults):
            features = inputs
            for spec in conv_specs:
                features = conv2d(features, *spec)
            # Presumably the last conv leaves a 1x1 spatial map so each
            # example becomes one 256-vector -- TODO confirm input size.
            features = tf.reshape(features, [-1, 256])
            # z_mu is created before z_log_sigma_sq, as in the original order.
            z_mu = dense(features, z_dim, nonlinearity=None, bn=False)
            z_log_sigma_sq = dense(features, z_dim, nonlinearity=None,
                                   bn=False)
    return z_mu, z_log_sigma_sq
# Example no. 3
# 0
def omniglot_conv_encoder(inputs,
                          r_dim,
                          is_training,
                          nonlinearity=None,
                          bn=True,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          counters={}):
    """Build a convolutional encoder that maps ``inputs`` to ``r_dim`` values.

    Six convolutions are applied to ``inputs``; the result is reshaped to
    rows of 256 values and projected by a single dense layer.

    Args:
        inputs: input tensor for the first convolution.
        r_dim: output size of the final dense layer.
        is_training: forwarded to the layer helpers through ``arg_scope``.
        nonlinearity: default activation for the layers; the final dense
            layer overrides it with ``None``.
        bn: default batch-norm flag for the layers; the final dense layer
            overrides it with ``False``.
        kernel_initializer: forwarded to the layer helpers.
        kernel_regularizer: forwarded to the layer helpers.
        counters: dict handed to ``get_name``. The default is a single dict
            object shared by every call that omits this argument.

    Returns:
        The output tensor of the final dense layer.
    """
    name = get_name("omniglot_conv_encoder", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        # NOTE(review): ``counters`` is not forwarded to conv2d/dense here;
        # confirm whether that is intentional before changing it, since it
        # may affect how the layers are named.
        with arg_scope([conv2d, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training):
            outputs = inputs
            outputs = conv2d(outputs, 64, 3, 1, "SAME")
            outputs = conv2d(outputs, 64, 3, 2, "SAME")
            outputs = conv2d(outputs, 128, 3, 1, "SAME")
            outputs = conv2d(outputs, 128, 3, 2, "SAME")
            outputs = conv2d(outputs, 256, 4, 1, "VALID")
            outputs = conv2d(outputs, 256, 4, 1, "VALID")
            # Presumably the last conv leaves a 1x1 spatial map so each
            # example becomes one 256-vector -- TODO confirm input size.
            outputs = tf.reshape(outputs, [-1, 256])
            # Use the project ``dense`` helper (the one configured by the
            # arg_scope above); ``tf.dense`` does not exist and would raise
            # AttributeError when the graph is built.
            r = dense(outputs, r_dim, nonlinearity=None, bn=False)
            return r