Ejemplo n.º 1
0
def conv2d_openai(x,
                  num_filters,
                  filter_size=[3, 3],
                  stride=[1, 1],
                  pad='SAME',
                  nonlinearity=None,
                  kernel_initializer=None,
                  init_scale=1.,
                  counters={},
                  init=False,
                  ema=None,
                  **kwargs):
    """Plain 2-D convolution layer (OpenAI-style signature, no weight norm).

    ``nonlinearity``, ``init_scale``, ``init`` and ``ema`` are accepted for
    signature compatibility but are not used by this implementation.
    """
    name = get_name("conv2d", counters)
    in_channels = int(x.get_shape()[-1])
    with tf.variable_scope(name):
        # Kernel layout: [fh, fw, in_channels, out_channels].
        V = tf.get_variable('V',
                            shape=filter_size + [in_channels, num_filters],
                            dtype=tf.float32,
                            initializer=kernel_initializer,
                            trainable=True)
        b = tf.get_variable('b',
                            shape=[num_filters],
                            dtype=tf.float32,
                            initializer=tf.constant_initializer(0.),
                            trainable=True)
        conv_out = tf.nn.conv2d(x, V, [1] + stride + [1], pad)
        return tf.nn.bias_add(conv_out, b)
Ejemplo n.º 2
0
def gated_resnet(x,
                 a=None,
                 gh=None,
                 sh=None,
                 nonlinearity=tf.nn.elu,
                 conv=conv2d,
                 dropout_p=0.0,
                 counters={},
                 **kwargs):
    """Gated residual block (PixelCNN++ style).

    ``a`` is an optional auxiliary input added as a short-cut; ``sh`` is an
    optional conditioning input projected into the gate pre-activations.
    The ``gh`` path is intentionally unimplemented in this codebase.
    """
    name = get_name("gated_resnet", counters)
    print("construct", name, "...")
    xs = int_shape(x)
    num_filters = xs[-1]
    kwargs["counters"] = counters
    with arg_scope([conv], **kwargs):
        h = conv(nonlinearity(x), num_filters)
        # Short-cut connection from the auxiliary input, if provided.
        if a is not None:
            h += nin(nonlinearity(a), num_filters)
        h = tf.nn.dropout(nonlinearity(h), keep_prob=1. - dropout_p)
        gates = conv(h, num_filters * 2)
        # Conditional generation: add a projection of the conditioning input.
        if sh is not None:
            gates += nin(sh, 2 * num_filters, nonlinearity=nonlinearity)
        if gh is not None:  # haven't finished this part
            pass
        linear_part, gate_part = tf.split(gates, 2, 3)
        return x + linear_part * tf.nn.sigmoid(gate_part)
def conv_decoder_28_binary(inputs,
                           nonlinearity=None,
                           bn=True,
                           kernel_initializer=None,
                           kernel_regularizer=None,
                           is_training=False,
                           counters={}):
    """Deconvolutional decoder mapping a latent code to a single-channel
    28x28 (binary) image; the output layer is linear with no batch norm."""
    name = get_name("conv_decoder_28_binary", counters)
    print("construct", name, "...")
    shared = dict(nonlinearity=nonlinearity,
                  bn=bn,
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer,
                  is_training=is_training,
                  counters=counters)
    with tf.variable_scope(name):
        with arg_scope([deconv2d, dense], **shared):
            h = dense(inputs, 128)
            h = tf.reshape(h, [-1, 1, 1, 128])
            # Spatial upsampling 1x1 -> 4x4 -> 7x7 -> 14x14 -> 28x28
            # (per deconv2d's SAME/VALID target-shape rule).
            for filters, ksize, strd, padding in [(128, 4, 1, "VALID"),
                                                  (64, 4, 1, "VALID"),
                                                  (64, 4, 2, "SAME"),
                                                  (32, 4, 2, "SAME")]:
                h = deconv2d(h, filters, ksize, strd, padding)
            return deconv2d(h, 1, 1, 1, "SAME", nonlinearity=None, bn=False)
Ejemplo n.º 4
0
def aggregator(r,
               num_c,
               z_dim,
               method=tf.reduce_mean,
               nonlinearity=None,
               bn=True,
               kernel_initializer=None,
               kernel_regularizer=None,
               is_training=False,
               counters={}):
    """Aggregate per-point representations into latent Gaussian parameters.

    Row 0 of the aggregate uses only the first ``num_c`` (context) rows of
    ``r``; row 1 uses all rows. Returns
    (context mu, context log sigma^2, full mu, full log sigma^2).
    """
    name = get_name("aggregator", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        with arg_scope([dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            context_agg = method(r[:num_c], axis=0, keepdims=True)
            full_agg = method(r, axis=0, keepdims=True)
            h = tf.concat([context_agg, full_agg], axis=0)
            hidden = 256
            for _ in range(3):
                h = dense(h, hidden)
            z_mu = dense(h, z_dim, nonlinearity=None, bn=False)
            z_log_sigma_sq = dense(h, z_dim, nonlinearity=None, bn=False)
            return z_mu[:1], z_log_sigma_sq[:1], z_mu[1:], z_log_sigma_sq[1:]
Ejemplo n.º 5
0
def mix_logistic_sampler(params, nr_logistic_mix=10, sample_range=3., counters={}):
    """Sample from a discretized mixture of logistics, truncated to roughly
    [loc - sample_range*scale, loc + sample_range*scale]."""
    name = get_name("mix_logistic_sampler", counters)
    print("construct", name, "...")
    # epsilon = sigmoid(-sample_range): the CDF tail mass cut off per side.
    epsilon = 1. / (1. + tf.exp(float(sample_range)))
    return sample_from_discretized_mix_logistic(params, nr_logistic_mix, epsilon)
def conv_encoder_32_large(inputs,
                          z_dim,
                          nonlinearity=None,
                          bn=True,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          is_training=False,
                          counters={}):
    """Convolutional encoder for 32x32 inputs; returns the mean and log
    variance of a diagonal Gaussian over a ``z_dim``-dimensional latent."""
    name = get_name("conv_encoder_32_large", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        with arg_scope([conv2d, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            h = inputs
            # (filters, kernel, stride, padding) for each conv stage.
            for layer_args in [(32, 1, 1, "SAME"),
                               (32, 1, 1, "SAME"),
                               (64, 4, 2, "SAME"),
                               (128, 4, 2, "SAME"),
                               (256, 4, 2, "SAME"),
                               (512, 4, 1, "VALID")]:
                h = conv2d(h, *layer_args)
            h = tf.reshape(h, [-1, 512])
            z_mu = dense(h, z_dim, nonlinearity=None, bn=False)
            z_log_sigma_sq = dense(h, z_dim, nonlinearity=None, bn=False)
            return z_mu, z_log_sigma_sq
def conv_decoder_32_large(inputs,
                          output_features=False,
                          nonlinearity=None,
                          bn=True,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          is_training=False,
                          counters={}):
    """Deconvolutional decoder from a latent code to a 3-channel 32x32
    image in [-1, 1]. With ``output_features=True`` the 32-channel feature
    map before the output layer is returned instead."""
    name = get_name("conv_decoder_32_large", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        with arg_scope([deconv2d, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            h = tf.reshape(dense(inputs, 512), [-1, 1, 1, 512])
            for filters, ksize, strd, padding in [(256, 4, 1, "VALID"),
                                                  (128, 4, 2, "SAME"),
                                                  (64, 4, 2, "SAME"),
                                                  (32, 4, 2, "SAME")]:
                h = deconv2d(h, filters, ksize, strd, padding)
            if output_features:
                return h
            # Sigmoid output rescaled from [0, 1] to [-1, 1].
            out = deconv2d(h, 3, 1, 1, "SAME", nonlinearity=tf.sigmoid, bn=False)
            return 2. * out - 1.
Ejemplo n.º 8
0
def deconv2d(inputs, num_filters, kernel_size, strides=1, padding='SAME', nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Transposed 2-D convolution with weight normalization
    (Salimans & Kingma, 2016), plus optional batch norm and nonlinearity.

    ``kernel_size`` and ``strides`` may be a scalar (applied to both
    spatial dimensions) or a [height, width] pair. The original body
    indexed into them unconditionally, so the scalar default ``strides=1``
    and in-file callers like ``deconv2d(h, 128, 4, 1, "VALID")`` raised a
    TypeError; scalars are now normalized to pairs first.
    """
    # Normalize scalar arguments to [h, w] pairs so both call styles work.
    filter_size = [kernel_size] * 2 if isinstance(kernel_size, int) else list(kernel_size)
    stride = [strides] * 2 if isinstance(strides, int) else list(strides)
    pad = padding
    x = inputs
    xs = int_shape(x)
    name = get_name('deconv2d', counters)
    # Output spatial size: exact multiple for SAME, plus kernel overlap for VALID.
    if pad == 'SAME':
        target_shape = [xs[0], xs[1] * stride[0], xs[2] * stride[1], num_filters]
    else:
        target_shape = [xs[0], xs[1] * stride[0] + filter_size[0] - 1,
                        xs[2] * stride[1] + filter_size[1] - 1, num_filters]
    with tf.variable_scope(name):
        # Raw direction kernel; transposed-conv layout is [fh, fw, out_ch, in_ch].
        V = tf.get_variable('V', shape=filter_size + [num_filters, int(x.get_shape()[-1])], dtype=tf.float32,
                            initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
        g = tf.get_variable('g', shape=[num_filters], dtype=tf.float32,
                            initializer=tf.constant_initializer(1.), trainable=True)
        b = tf.get_variable('b', shape=[num_filters], dtype=tf.float32,
                            initializer=tf.constant_initializer(0.), trainable=True)

        # Weight normalization (Salimans & Kingma, 2016): per-output-channel
        # scale g on the unit-norm kernel direction.
        W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3])

        # Transposed convolution followed by the bias.
        x = tf.nn.conv2d_transpose(x, W, target_shape, [1] + stride + [1], padding=pad)
        outputs = tf.nn.bias_add(x, b)

        if bn:
            outputs = tf.layers.batch_normalization(outputs, training=is_training)
        if nonlinearity is not None:
            outputs = nonlinearity(outputs)
        print("    + deconv2d", int_shape(inputs), int_shape(outputs), nonlinearity, bn)
        return outputs
Ejemplo n.º 9
0
def reverse_pixel_cnn_28_binary(x,
                                masks,
                                context=None,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    """PixelCNN over masked 28x28 inputs scanning in the reverse direction
    (up/left shifts instead of the usual down/right); returns an
    nr_filters-channel feature map.

    NOTE(review): despite the default ``bn=True``, the assert below forces
    callers to pass ``bn=False``.
    """
    name = get_name("reverse_pixel_cnn_28_binary", counters)
    # Zero out unobserved pixels and append the mask as an extra channel so
    # the network can tell observed pixels from blanked ones.
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=None,
                       sh=context,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p):
            with arg_scope([
                    gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d,
                    up_shifted_deconv2d, up_left_shifted_deconv2d
            ],
                           bn=bn,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training,
                           counters=counters):
                xs = int_shape(x)
                x_pad = tf.concat(
                    [x, tf.ones(xs[:-1] + [1])], 3
                )  # add channel of ones to distinguish image from padding later on

                u_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[2, 3]))
                ]  # stream for pixels above
                ul_list = [up_shift(up_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                        left_shift(up_left_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left

                # Stack of gated residual blocks; the ul stream also receives
                # the u stream as its auxiliary input.
                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=up_left_shifted_conv2d))

                # 1x1 projection of the final ul stream to nr_filters channels.
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                return x_out
Ejemplo n.º 10
0
def fc_encoder(X,
               y,
               r_dim,
               nonlinearity=None,
               bn=True,
               kernel_initializer=None,
               kernel_regularizer=None,
               is_training=False,
               counters={}):
    """Fully connected encoder mapping (X, y) pairs through two
    residual-style stages to an r_dim representation."""
    inputs = tf.concat([X, y[:, None]], axis=1)
    name = get_name("fc_encoder", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        with arg_scope([dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            size = 256

            def residual_stage(h_in):
                # One dense layer, then a sum of two linear maps (from the
                # layer output and the stage input) under the nonlinearity.
                h = dense(h_in, size)
                return nonlinearity(
                    dense(h, size, nonlinearity=None) +
                    dense(h_in, size, nonlinearity=None))

            h = residual_stage(inputs)
            h = residual_stage(h)
            h = dense(h, size)
            return dense(h, r_dim, nonlinearity=None, bn=False)
def conditioning_network_28(x,
                            masks,
                            nr_filters,
                            is_training=True,
                            nonlinearity=None,
                            bn=True,
                            kernel_initializer=None,
                            kernel_regularizer=None,
                            counters={}):
    """Plain convolutional conditioning network over masked 28x28 inputs;
    returns an nr_filters-channel feature map (final layer linear, no bn)."""
    name = get_name("conditioning_network_28", counters)
    # Zero out unobserved pixels, append the mask as an extra channel, then
    # append a channel of ones.
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    xs = int_shape(x)
    x = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
    with tf.variable_scope(name):
        with arg_scope([conv2d, residual_block, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training,
                       counters=counters):
            h = conv2d(x, nr_filters, 4, 1, "SAME")
            # Four more stride-1 4x4 convolutions at constant width.
            for _ in range(4):
                h = conv2d(h, nr_filters, 4, 1, "SAME")
            return conv2d(h, nr_filters, 1, 1, "SAME", nonlinearity=None, bn=False)
Ejemplo n.º 12
0
def conditional_decoder(x,
                        z,
                        nonlinearity=None,
                        bn=True,
                        kernel_initializer=None,
                        kernel_regularizer=None,
                        is_training=False,
                        counters={}):
    """Decode targets conditioned on a latent ``z`` through a stack of
    gated (tanh * sigmoid) fully connected layers; returns one scalar per
    row of ``x``."""
    name = get_name("conditional_decoder", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        with arg_scope([dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training):
            size = 256
            batch_size = tf.shape(x)[0]
            # Tile x across z's width and z across the batch so the two can
            # be combined row-wise below.
            x = tf.tile(x, tf.stack([1, int_shape(z)[1]]))
            z = tf.tile(z, tf.stack([batch_size, 1]))
            xz = x

            def gated(pre):
                # Gated activation applied to a shared pre-activation.
                return tf.nn.tanh(pre) * tf.sigmoid(pre)

            outputs = gated(dense(xz, size, nonlinearity=None) + dense(
                z, size, nonlinearity=None))
            for _ in range(4):
                outputs = gated(dense(outputs, size, nonlinearity=None) + dense(
                    z, size, nonlinearity=None))
            outputs = dense(outputs, 1, nonlinearity=None, bn=False)
            return tf.reshape(outputs, shape=(batch_size, ))
Ejemplo n.º 13
0
def omniglot_conv_encoder(inputs,
                          r_dim,
                          is_training,
                          nonlinearity=None,
                          bn=True,
                          kernel_initializer=None,
                          kernel_regularizer=None,
                          counters={}):
    """Convolutional encoder for Omniglot images producing an ``r_dim``
    representation vector.

    Bug fix: the final projection called ``tf.dense`` — TensorFlow has no
    such attribute, and the ``nonlinearity``/``bn`` keywords belong to this
    file's local ``dense`` helper, which the other encoders here use (see
    conv_encoder_32_large). It now calls the local ``dense``.
    """
    name = get_name("omniglot_conv_encoder", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        with arg_scope([conv2d, dense],
                       nonlinearity=nonlinearity,
                       bn=bn,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training):
            outputs = inputs
            outputs = conv2d(outputs, 64, 3, 1, "SAME")
            outputs = conv2d(outputs, 64, 3, 2, "SAME")
            outputs = conv2d(outputs, 128, 3, 1, "SAME")
            outputs = conv2d(outputs, 128, 3, 2, "SAME")
            outputs = conv2d(outputs, 256, 4, 1, "VALID")
            outputs = conv2d(outputs, 256, 4, 1, "VALID")
            outputs = tf.reshape(outputs, [-1, 256])
            # Final linear projection via the local dense helper (was tf.dense).
            r = dense(outputs, r_dim, nonlinearity=None, bn=False)
            return r
def cond_pixel_cnn(x,
                   gh=None,
                   sh=None,
                   nonlinearity=tf.nn.elu,
                   nr_resnet=5,
                   nr_filters=100,
                   nr_logistic_mix=10,
                   bn=False,
                   dropout_p=0.0,
                   kernel_initializer=None,
                   kernel_regularizer=None,
                   is_training=False,
                   counters={}):
    """Conditional PixelCNN over ``x``; the ``gh``/``sh`` conditioning
    inputs are forwarded to every gated residual block. Returns a
    10 * nr_logistic_mix channel map of discretized-logistic-mixture
    parameters."""
    # NOTE(review): the scope string is "conv_pixel_cnn" while the function
    # is named cond_pixel_cnn — confirm whether this is intentional before
    # unifying (renaming would change variable scopes / checkpoints).
    name = get_name("conv_pixel_cnn", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=gh,
                       sh=sh,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p,
                       counters=counters):
            with arg_scope(
                [gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                xs = int_shape(x)
                x_pad = tf.concat(
                    [x, tf.ones(xs[:-1] + [1])], 3
                )  # add channel of ones to distinguish image from padding later on

                u_list = [
                    down_shift(
                        down_shifted_conv2d(x_pad,
                                            num_filters=nr_filters,
                                            filter_size=[2, 3]))
                ]  # stream for pixels above
                ul_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                        right_shift(down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left
                receptive_field = (2, 3)
                # Each gated residual block grows the tracked receptive field
                # by (1, 2); the ul stream also receives the u stream.
                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=down_right_shifted_conv2d))
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                # 1x1 projection to the mixture-parameter channels.
                x_out = nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)
                print("    * receptive_field", receptive_field)
                return x_out
Ejemplo n.º 15
0
def gaussian_sampler(loc, scale, counters={}):
    """Reparameterized Gaussian sample: loc + scale * eps with eps ~ N(0, 1)."""
    name = get_name("gaussian_sampler", counters)
    print("construct", name, "...")
    with tf.variable_scope(name):
        standard_normal = tf.distributions.Normal(loc=0., scale=1.)
        eps = standard_normal.sample(sample_shape=int_shape(loc), seed=None)
        z = loc + tf.multiply(eps, scale)
        print("    + gaussian_sampler", int_shape(z))
        return z
def context_encoder(contexts,
                    masks,
                    is_training,
                    nr_resnet=5,
                    nr_filters=100,
                    nonlinearity=None,
                    bn=False,
                    kernel_initializer=None,
                    kernel_regularizer=None,
                    counters={}):
    """Encode masked context pixels with a reverse-direction (up/left
    shifted) PixelCNN stack; returns an nr_filters-channel feature map."""
    name = get_name("context_encoder", counters)
    print("construct", name, "...")
    # Zero out unobserved pixels and append the mask as an extra channel.
    x = contexts * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    if bn:
        print("*** Attention *** using bn in the context encoder\n")
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       nonlinearity=nonlinearity,
                       counters=counters):
            with arg_scope(
                [gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                xs = int_shape(x)
                x_pad = tf.concat(
                    [x, tf.ones(xs[:-1] + [1])], 3
                )  # add channel of ones to distinguish image from padding later on

                u_list = [
                    up_shift(
                        up_shifted_conv2d(x_pad,
                                          num_filters=nr_filters,
                                          filter_size=[2, 3]))
                ]  # stream for pixels above
                ul_list = [up_shift(up_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                        left_shift(up_left_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left
                receptive_field = (2, 3)
                # Each gated residual block grows the tracked receptive field
                # by (1, 2); the ul stream also receives the u stream.
                for rep in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=up_left_shifted_conv2d))
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                # 1x1 projection of the final ul stream to nr_filters channels.
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                print("    * receptive_field", receptive_field)
                return x_out
Ejemplo n.º 17
0
def mlp(X,
        scope="mlp",
        params=None,
        nonlinearity=None,
        bn=True,
        kernel_initializer=None,
        kernel_regularizer=None,
        is_training=False,
        counters={}):
    """Four-hidden-layer MLP with a scalar linear output per row of ``X``.

    When ``params`` is given it is consumed in place: the list is reversed
    and each layer pops its W then its b off the end.
    """
    name = get_name(scope, counters)
    print("construct", name, "...")
    external_weights = params is not None
    if external_weights:
        # Reverse in place so pop() yields the weights in layer order.
        params.reverse()
    with tf.variable_scope(name):
        default_args = {
            "nonlinearity": nonlinearity,
            "bn": bn,
            "kernel_initializer": kernel_initializer,
            "kernel_regularizer": kernel_regularizer,
            "is_training": is_training,
            "counters": counters,
        }
        with arg_scope([dense], **default_args):
            batch_size = tf.shape(X)[0]
            size = 256
            h = X
            for _ in range(4):
                if external_weights:
                    h = dense(h, size, W=params.pop(), b=params.pop())
                else:
                    h = dense(h, size)
            if external_weights:
                h = dense(h, 1, nonlinearity=None, W=params.pop(), b=params.pop())
            else:
                h = dense(h, 1, nonlinearity=None)
            return tf.reshape(h, shape=(batch_size, ))
Ejemplo n.º 18
0
def deconv2d_openai(x,
                    num_filters,
                    filter_size=[3, 3],
                    stride=[1, 1],
                    pad='SAME',
                    nonlinearity=None,
                    kernel_initializer=None,
                    init_scale=1.,
                    counters={},
                    init=False,
                    ema=None,
                    **kwargs):
    """Plain transposed 2-D convolution (OpenAI-style signature).

    ``nonlinearity``, ``init_scale``, ``init`` and ``ema`` are accepted for
    signature compatibility but are not used by this implementation.
    """
    name = get_name("deconv2d", counters)
    xs = int_shape(x)
    # Output spatial size: exact multiple for SAME, plus kernel overlap for VALID.
    if pad == 'SAME':
        target_shape = [xs[0], xs[1] * stride[0], xs[2] * stride[1], num_filters]
    else:
        target_shape = [xs[0],
                        xs[1] * stride[0] + filter_size[0] - 1,
                        xs[2] * stride[1] + filter_size[1] - 1,
                        num_filters]
    with tf.variable_scope(name):
        # Transposed-conv kernel layout: [fh, fw, out_channels, in_channels].
        V = tf.get_variable('V',
                            shape=filter_size + [num_filters, int(x.get_shape()[-1])],
                            dtype=tf.float32,
                            initializer=kernel_initializer,
                            trainable=True)
        b = tf.get_variable('b',
                            shape=[num_filters],
                            dtype=tf.float32,
                            initializer=tf.constant_initializer(0.),
                            trainable=True)
        out = tf.nn.conv2d_transpose(x,
                                     V,
                                     target_shape, [1] + stride + [1],
                                     padding=pad)
        return tf.nn.bias_add(out, b)
Ejemplo n.º 19
0
def dense(inputs,
          num_outputs,
          W=None,
          b=None,
          nonlinearity=None,
          bn=False,
          kernel_initializer=None,
          kernel_regularizer=None,
          is_training=False,
          counters={}):
    """Fully connected layer with optional externally supplied weights,
    optional batch normalization and optional nonlinearity."""
    name = get_name('dense', counters)
    with tf.variable_scope(name):
        # Create variables only when the caller did not supply them.
        if W is None:
            W = tf.get_variable(
                'W',
                shape=[int(inputs.get_shape()[1]), num_outputs],
                dtype=tf.float32,
                trainable=True,
                initializer=kernel_initializer,
                regularizer=kernel_regularizer)
        if b is None:
            b = tf.get_variable('b',
                                shape=[num_outputs],
                                dtype=tf.float32,
                                trainable=True,
                                initializer=tf.constant_initializer(0.),
                                regularizer=None)

        # Affine transform; bias is broadcast across the batch dimension.
        outputs = tf.matmul(inputs, W) + tf.reshape(b, [1, num_outputs])
        if bn:
            outputs = tf.layers.batch_normalization(outputs,
                                                    training=is_training)
        if nonlinearity is not None:
            outputs = nonlinearity(outputs)
        print("    + dense", int_shape(inputs), int_shape(outputs),
              nonlinearity, bn)
        return outputs
def forward_pixel_cnn_32(x, context, nr_logistic_mix=10, nr_resnet=1, nr_filters=100, dropout_p=0.0, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """U-net style PixelCNN++ over 32x32 inputs, conditioned on ``context``
    only in the first (full-resolution) encoding stage; returns
    10 * nr_logistic_mix channels of discretized-logistic-mixture parameters.

    NOTE(review): despite the default ``bn=True``, the assert below forces
    callers to pass ``bn=False``.
    """
    name = get_name("forward_pixel_cnn_32", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=None, sh=None, nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope([gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d, down_shifted_deconv2d, down_right_shifted_deconv2d], bn=bn, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, is_training=is_training, counters=counters):
                xs = int_shape(x)
                x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on

                # /////// up (encoding) pass: two stride-2 downsamplings ////////
                u_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 3]))] # stream for pixels above
                ul_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                        right_shift(down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left

                # Full-resolution blocks; only this stage sees the context.
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], sh=context, conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], sh=context, conv=down_right_shifted_conv2d))

                # Downsample both streams by 2.
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # Downsample both streams by 2 again.
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # /////// down pass ////////
                # Decode back up, consuming the stored encoder activations
                # as skip connections (popped in reverse order of creation).

                u = u_list.pop()
                ul = ul_list.pop()

                for rep in range(nr_resnet):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=down_right_shifted_conv2d)

                # Upsample both streams by 2.
                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])

                for rep in range(nr_resnet+1):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=down_right_shifted_conv2d)

                # Upsample both streams by 2 again.
                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])


                for rep in range(nr_resnet+1):
                    u = gated_resnet(u, u_list.pop(), sh=None, conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()],3), sh=None, conv=down_right_shifted_conv2d)

                # 1x1 projection to the mixture-parameter channels.
                x_out = nin(tf.nn.elu(ul),10*nr_logistic_mix)
                # Sanity check: every stored encoder activation was consumed.
                assert len(u_list) == 0
                assert len(ul_list) == 0
                return x_out