Example #1
def discriminator(inp, is_training, init=False, reuse=False, getter=None):
    with tf.variable_scope('discriminator_model', reuse=reuse, custom_getter=getter):
        counter = {}
        x = tf.reshape(inp, [-1, 32, 32, 3])

        x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout_0')

        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 96, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter)

        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_1')

        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 192, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter)

        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_2')

        x = nn.conv2d(x, 192, pad='VALID', nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)
        x = tf.layers.max_pooling2d(x, pool_size=6, strides=1,
                                    name='avg_pool_0')  # max pooling, despite the 'avg_pool' name
        x = tf.squeeze(x, [1, 2])

        intermediate_layer = x

        logits = nn.dense(x, 10, nonlinearity=None, init=init, counters=counter, init_scale=0.1)

        return logits, intermediate_layer
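A minimal usage sketch (an addition, not part of the original example), wiring the discriminator in TF 1.x graph mode: one data-dependent init pass with init=True, then a regular pass that reuses the initialized weights. The placeholder shapes are assumptions.

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 32, 32, 3], name='images')
is_training = tf.placeholder(tf.bool, [], name='is_training')

# one-time data-dependent initialization pass
_ = discriminator(images, is_training, init=True)
# regular pass, sharing the initialized weights
logits, features = discriminator(images, is_training, reuse=True)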
Example #2
def classifier(inp, is_training, init=False, reuse=False, getter=None, category=125):
    with tf.variable_scope('discriminator_model', reuse=reuse, custom_getter=getter):
        counter = {}
        #x = tf.reshape(inp, [-1, 32, 32, 3])
        x = tf.reshape(inp, [-1, 200, 30, 3])
        x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout_0')

        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)                #  64*200*30*96
        x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter)                #  64*200*30*96
        #x = nn.conv2d(x, 96, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter) 
        x = nn.conv2d(x, 96, stride=[5, 2], nonlinearity=leakyReLu, init=init, counters=counter) #  64*40*15*96
        
        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_1')               #  64*40*15*96

        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)               #  64*40*15*192
        x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter)               #  64*40*15*192
        #x = nn.conv2d(x, 192, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter)
        x = nn.conv2d(x, 192, stride=[5, 2], nonlinearity=leakyReLu, init=init, counters=counter)#  64*8*8*192

        x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_2')               #  64*8*8*192

        x = nn.conv2d(x, 192, pad='VALID', nonlinearity=leakyReLu, init=init, counters=counter)  #  64*6*6*192
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)                  #  64*6*6*192
        x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init)                  #  64*6*6*192
        x = tf.layers.max_pooling2d(x, pool_size=6, strides=1, name='avg_pool_0')                #  64*1*1*192 (max pooling, despite the name)
        x = tf.squeeze(x, [1, 2])                                                                #  64*192

        intermediate_layer = x

        #logits = nn.dense(x, 10, nonlinearity=None, init=init, counters=counter, init_scale=0.1)
        logits = nn.dense(x, category, nonlinearity=None, init=init, counters=counter, init_scale=0.1) # 64*125
        print('logits:', logits)

        return logits, intermediate_layer
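The trailing shape comments follow from SAME-padded convolutions, whose spatial output size is ceil(input / stride); the leading 64 is an assumed batch size. A standalone check of the 200x30 path:

import math

def same_out(size, stride):
    # output extent of a SAME-padded convolution
    return math.ceil(size / stride)

h, w = 200, 30
h, w = same_out(h, 5), same_out(w, 2)  # 40, 15 after the first strided conv
h, w = same_out(h, 5), same_out(w, 2)  # 8, 8 after the second
print(h, w)  # 8 8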
Example #3
def dec_down(
        gs, zs_posterior, training, init=False, dropout_p=0.5,
        n_scales=1, n_residual_blocks=2, activation="elu",
        n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    zs_posterior = list(zs_posterior)
    with model_arg_scope(
            init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []  # hidden units
        ps = []  # priors
        zs = []  # prior samples
        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)
        for l in range(n_scales):
            # level module
            ## hidden units
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if l < n_latent_scales:
                ## prior
                # both leftovers from the fuller variant in Example #15; unused below
                spatial_shape = h.shape.as_list()[1]
                n_h_channels = h.shape.as_list()[-1]

                ### no spatial correlations
                p = latent_parameters(h)
                ps.append(p)
                z_prior = latent_sample(p)
                zs.append(z_prior)

                if training:
                    ## posterior
                    z = zs_posterior.pop(0)
                else:
                    ## prior
                    z = z_prior
                for i in range(n_residual_blocks // 2):
                    n_h_channels = h.shape.as_list()[-1]
                    h = tf.concat([h, z], axis=-1)
                    h = nn.nin(h, n_h_channels)
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            else:
                for i in range(n_residual_blocks // 2):
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        assert not gs
        if training:
            assert not zs_posterior

        return hs, ps, zs
Example #4
    def build(self, input_shape):
        B, H, W, C = input_shape
        self.normalize = normalize(name='norm')
        self.nin_q = nn.nin(name='q', num_units=C)
        self.nin_k = nn.nin(name='k', num_units=C)
        self.nin_v = nn.nin(name='v', num_units=C)

        self.nin_proj_out = nn.nin(name='proj_out', num_units=C, init_scale=0.)
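This build() allocates a normalization layer plus 1x1 NIN projections for queries, keys, and values: the setup of a spatial self-attention block. A hypothetical call() consistent with those attributes (the matmul formulation and residual wiring are assumptions, sketched in the DDPM style, not taken from the original):

    def call(self, x):
        B, H, W, C = x.shape.as_list()
        h = self.normalize(x)
        q = tf.reshape(self.nin_q(h), [-1, H * W, C])
        k = tf.reshape(self.nin_k(h), [-1, H * W, C])
        v = tf.reshape(self.nin_v(h), [-1, H * W, C])
        # scaled dot-product attention over all spatial positions
        w = tf.nn.softmax(tf.matmul(q, k, transpose_b=True) * (C ** -0.5), axis=-1)
        h = tf.reshape(tf.matmul(w, v), [-1, H, W, C])
        # residual connection; proj_out starts at zero (init_scale=0.)
        return x + self.nin_proj_out(h)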
Example #5
def enc_up(x,
           c,
           init=False,
           dropout_p=0.5,
           n_scales=1,
           n_residual_blocks=2,
           activation="elu",
           n_filters=64,
           max_filters=128):
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        # outputs
        hs = []
        # prepare input
        # This line is also odd: why concatenate x and c here at all?
        # xc = tf.concat([x,c], axis = -1)
        xc = x
        h = nn.nin(xc, n_filters)
        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                # It looks like the channel count stays at 128 and never grows past that.
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)

        return hs
Example #6
    def build(self, input_shape):
        B, H, W, C = input_shape
        if self.out_ch is None:
            self.out_ch = C
        self.normalize_1 = normalize('norm1')
        self.normalize_2 = normalize('norm2')

        self.dense = nn.dense(name='temb_proj',
                              num_units=self.out_ch,
                              spec_norm=self.spec_norm)
        self.conv2d_1 = nn.conv2d(name='conv1',
                                  num_units=self.out_ch,
                                  spec_norm=self.spec_norm)

        self.conv2d_2 = nn.conv2d(name='conv2',
                                  num_units=self.out_ch,
                                  init_scale=0.,
                                  spec_norm=self.spec_norm,
                                  use_scale=self.use_scale)
        if self.conv_shortcut:
            self.conv2d_shortcut = nn.conv2d(name='conv_shortcut',
                                             num_units=self.out_ch,
                                             spec_norm=self.spec_norm)
        else:
            self.nin_shortcut = nn.nin(name='nin_shortcut',
                                       num_units=self.out_ch,
                                       spec_norm=self.spec_norm)
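Here build() assembles a residual block conditioned on a timestep embedding temb: two normalized convolutions, a dense projection of temb added in between, and a learned shortcut for channel changes. A hypothetical call() matching these attributes (the swish nonlinearity and the exact shortcut condition are assumptions; dropout omitted):

    def call(self, x, temb):
        h = self.conv2d_1(tf.nn.swish(self.normalize_1(x)))
        # add the projected timestep embedding, broadcast over the spatial dims
        h += self.dense(tf.nn.swish(temb))[:, None, None, :]
        h = self.conv2d_2(tf.nn.swish(self.normalize_2(h)))
        if x.shape.as_list()[-1] != self.out_ch:
            x = self.conv2d_shortcut(x) if self.conv_shortcut else self.nin_shortcut(x)
        return x + h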
Example #7
def dec_up(
    c,
    init=False,
    dropout_p=0.5,
    n_scales=1,
    n_residual_blocks=2,
    activation="elu",
    n_filters=64,
    max_filters=128,
):
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        # outputs
        hs = []
        # prepare input
        h = nn.nin(c, n_filters)
        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        return hs
Example #8
def enc_up(
    x,
    c,
    init=False,
    dropout_p=0.5,
    n_scales=1,
    n_residual_blocks=2,
    activation="elu",
    n_filters=64,
    max_filters=128,
):
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        """c is actually not used"""
        # outputs
        hs = []
        # prepare input
        # xc = tf.concat([x,c], axis = -1)
        xc = x
        h = nn.nin(xc, n_filters)
        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        return hs
Example #9
def classifier(
        x, n_out, init=False, dropout_p=0.5,
        activation="elu"):
    with model_arg_scope(
            init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []
        # prepare input
        x_shape = x.shape.as_list()  # static shape; tf.shape(x) would give the dynamic one
        h = tf.reshape(x, [x_shape[0], 1, 1, x_shape[1] * x_shape[2] * x_shape[3]])
        h = nn.activate(h)
        h = nn.nin(h, 1024)
        h = nn.activate(h)
        h = nn.nin(h, n_out)
        h = tf.reshape(h, [x_shape[0], n_out])
        return h
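Reshaping to (batch, 1, 1, H*W*C) makes each nn.nin act as a fully connected layer, so this head is effectively a two-layer MLP. A plain-TF equivalent for comparison (tf.layers.dense standing in for nn.nin; an illustration, not the repo's code):

def classifier_head(x, n_out):
    h = tf.reshape(x, [x.shape.as_list()[0], -1])  # flatten H*W*C
    h = tf.nn.elu(h)
    h = tf.layers.dense(h, 1024)
    h = tf.nn.elu(h)
    return tf.layers.dense(h, n_out)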
Example #10
def cfn(
        x, init=False, dropout_p=0.5,
        n_scales=1, n_residual_blocks=2, activation="elu",
        n_filters=64, max_filters=128):
    with model_arg_scope(
            init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []
        # prepare input
        xc = x
        h = nn.nin(xc, n_filters)
        for l in range(n_scales):
            # level module
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        h_shape = h.shape.as_list()
        h = tf.reshape(h, [h_shape[0], 1, 1, h_shape[1] * h_shape[2] * h_shape[3]])
        h = nn.nin(h, 2 * max_filters)
        hs.append(h)
        return hs
Example #11
def enc_down(
        gs, init=False, dropout_p=0.5,
        n_scales=1, n_residual_blocks=2, activation="elu",
        n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    with model_arg_scope(
            init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = [] # hidden units
        qs = [] # posteriors
        zs = [] # samples from posterior
        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)
        for l in range(n_scales):
            # level module
            ## hidden units
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if l < n_latent_scales:
                ## posterior parameters
                q = latent_parameters(h)
                qs.append(q)
                ## posterior sample
                z = latent_sample(q)
                zs.append(z)
                ## sample feedback
                for i in range(n_residual_blocks // 2):
                    gz = tf.concat([gs.pop(), z], axis=-1)
                    h = nn.residual_block(h, gz)
                    hs.append(h)
            else:
                """ no need to go down any further
                for i in range(n_residual_blocks // 2):
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
                """
                break
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        # assert not gs  # no longer true, since we break out of the loop above

        return hs, qs, zs
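For orientation, these pieces compose in the usual hierarchical-VAE pattern: enc_up runs bottom-up over the image, enc_down produces posteriors and samples, dec_up (Examples #7/#12) runs bottom-up over the conditioning, and dec_down consumes the posterior samples top-down. A rough sketch of that wiring (variable names and the exact call order are assumptions based on this structure):

gs_x = enc_up(x, c)                               # bottom-up features of the image
hs, qs, zs_posterior = enc_down(gs_x)             # posterior parameters and samples
gs_c = dec_up(c)                                  # bottom-up features of the conditioning
ds, ps, zs = dec_down(gs_c, zs_posterior, training=True)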
Example #12
def dec_up(c,
           init=False,
           dropout_p=0.5,
           n_scales=1,
           n_residual_blocks=2,
           activation="elu",
           n_filters=64,
           max_filters=128):
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        hs = []
        h = nn.nin(c, n_filters)
        for l in range(n_scales):
            for i in range(n_residual_blocks):
                h = nn.residual_block(h)
                hs.append(h)
            if l + 1 < n_scales:
                n_filters = min(2 * n_filters, max_filters)
                h = nn.downsample(h, n_filters)
        return hs
Example #13
def enc_down(gs,
             init=False,
             dropout_p=0.5,
             n_scales=1,
             n_residual_blocks=2,
             activation="elu",
             n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        hs = []  # hidden units
        qs = []  # posteriors
        zs = []  # samples from posterior
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)
        for l in range(n_scales):
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if l < n_latent_scales:
                q = latent_parameters(h)  # posterior parameters
                qs.append(q)
                z = latent_sample(q)  # posterior sample
                zs.append(z)
                for i in range(n_residual_blocks // 2):
                    gz = tf.concat([gs.pop(), z], axis=-1)
                    h = nn.residual_block(h, gz)
                    hs.append(h)
            else:
                break
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        return hs, qs, zs
Example #14
def model_spec(x, init=False, ema=None, dropout_p=args.dropout_p):
    counters = {}
    with scopes.arg_scope([
            nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.aux_gated_resnet,
            nn.dense
    ],
                          counters=counters,
                          init=init,
                          ema=ema,
                          dropout_p=dropout_p):

        # ////////// up pass through pixelCNN ////////
        xs = nn.int_shape(x)
        x_pad = tf.concat([
            x, tf.ones(xs[:-1] + [1])
        ], 3)  # add channel of ones to distinguish image from padding later on
        u_list = [
            nn.down_shift(
                nn.down_shifted_conv2d(x_pad,
                                       num_filters=args.nr_filters,
                                       filter_size=[2, 3]))
        ]  # stream for pixels above
        ul_list = [nn.down_shift(nn.down_shifted_conv2d(x_pad, num_filters=args.nr_filters, filter_size=[2, 3])) + \
                   nn.right_shift(nn.down_right_shifted_conv2d(x_pad, num_filters=args.nr_filters, filter_size=[2, 1]))] # stream for up and to the left

        for rep in range(args.nr_resnet):
            u_list.append(
                nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(
                nn.aux_gated_resnet(ul_list[-1],
                                    u_list[-1],
                                    conv=nn.down_right_shifted_conv2d))

        u_list.append(
            nn.down_shifted_conv2d(u_list[-1],
                                   num_filters=args.nr_filters,
                                   stride=[2, 2]))
        ul_list.append(
            nn.down_right_shifted_conv2d(ul_list[-1],
                                         num_filters=args.nr_filters,
                                         stride=[2, 2]))

        for rep in range(args.nr_resnet):
            u_list.append(
                nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(
                nn.aux_gated_resnet(ul_list[-1],
                                    u_list[-1],
                                    conv=nn.down_right_shifted_conv2d))

        u_list.append(
            nn.down_shifted_conv2d(u_list[-1],
                                   num_filters=args.nr_filters,
                                   stride=[2, 2]))
        ul_list.append(
            nn.down_right_shifted_conv2d(ul_list[-1],
                                         num_filters=args.nr_filters,
                                         stride=[2, 2]))

        for rep in range(args.nr_resnet):
            u_list.append(
                nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(
                nn.aux_gated_resnet(ul_list[-1],
                                    u_list[-1],
                                    conv=nn.down_right_shifted_conv2d))

        # /////// down pass ////////
        u = u_list.pop()
        ul = ul_list.pop()
        for rep in range(args.nr_resnet):
            u = nn.aux_gated_resnet(u,
                                    u_list.pop(),
                                    conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul,
                                     tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        u = nn.down_shifted_deconv2d(u,
                                     num_filters=args.nr_filters,
                                     stride=[2, 2])
        ul = nn.down_right_shifted_deconv2d(ul,
                                            num_filters=args.nr_filters,
                                            stride=[2, 2])

        for rep in range(args.nr_resnet + 1):
            u = nn.aux_gated_resnet(u,
                                    u_list.pop(),
                                    conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul,
                                     tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        u = nn.down_shifted_deconv2d(u,
                                     num_filters=args.nr_filters,
                                     stride=[2, 2])
        ul = nn.down_right_shifted_deconv2d(ul,
                                            num_filters=args.nr_filters,
                                            stride=[2, 2])

        for rep in range(args.nr_resnet + 1):
            u = nn.aux_gated_resnet(u,
                                    u_list.pop(),
                                    conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul,
                                     tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        x_out = nn.nin(tf.nn.elu(ul), 10 * args.nr_logistic_mix)

        assert len(u_list) == 0
        assert len(ul_list) == 0

        return x_out
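The causal structure of the up pass hinges on nn.down_shift and nn.right_shift. In typical PixelCNN++ implementations these are zero-padded one-pixel shifts that keep the receptive field strictly above, or to the left of, the current pixel; a plausible standalone version (an assumption about this repo's nn module):

def down_shift(x):
    # drop the bottom row and pad a zero row on top,
    # so each output row only sees the rows above it
    return tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])[:, :-1, :, :]

def right_shift(x):
    # drop the rightmost column and pad a zero column on the left
    return tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])[:, :, :-1, :]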
Example #15
def dec_down(gs,
             zs_posterior,
             training,
             init=False,
             dropout_p=0.5,
             n_scales=1,
             n_residual_blocks=2,
             activation="elu",
             n_latent_scales=2):
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    zs_posterior = list(zs_posterior)
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        # outputs
        hs = []  # hidden units
        ps = []  # priors
        zs = []  # prior samples
        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)
        for l in range(n_scales):
            # level module
            ## hidden units
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if l < n_latent_scales:
                ## prior
                spatial_shape = h.shape.as_list()[1]
                n_h_channels = h.shape.as_list()[-1]
                if spatial_shape == 1:
                    ### no spatial correlations
                    p = latent_parameters(h)
                    ps.append(p)
                    z_prior = latent_sample(p)
                    zs.append(z_prior)
                else:
                    ### four autoregressively modeled groups
                    if training:
                        z_posterior_groups = nn.split_groups(zs_posterior[0])
                    p_groups = []
                    z_groups = []
                    p_features = tf.space_to_depth(nn.residual_block(h), 2)
                    for i in range(4):
                        p_group = latent_parameters(p_features,
                                                    num_filters=n_h_channels)
                        p_groups.append(p_group)
                        z_group = latent_sample(p_group)
                        z_groups.append(z_group)
                        # ar feedback sampled from
                        if training:
                            feedback = z_posterior_groups.pop(0)
                        else:
                            feedback = z_group
                        # prepare input for next group
                        if i + 1 < 4:
                            p_features = nn.residual_block(
                                p_features, feedback)
                    if training:
                        assert not z_posterior_groups
                    # complete prior parameters
                    p = nn.merge_groups(p_groups)
                    ps.append(p)
                    # complete prior sample
                    z_prior = nn.merge_groups(z_groups)
                    zs.append(z_prior)
                ## vae feedback sampled from
                if training:
                    ## posterior
                    z = zs_posterior.pop(0)
                else:
                    ## prior
                    z = z_prior
                for i in range(n_residual_blocks // 2):
                    n_h_channels = h.shape.as_list()[-1]
                    h = tf.concat([h, z], axis=-1)
                    h = nn.nin(h, n_h_channels)
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            else:
                for i in range(n_residual_blocks // 2):
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        assert not gs
        if training:
            assert not zs_posterior

        return hs, ps, zs
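The four autoregressive groups are produced by nn.split_groups and reassembled by nn.merge_groups. Definitions consistent with the tf.space_to_depth call above would be (a guess at the repo's nn module, not quoted from it):

def split_groups(x, bs=2):
    # fold each bs x bs spatial block into channels, then split into bs**2 groups
    return tf.split(tf.space_to_depth(x, bs), bs ** 2, axis=3)

def merge_groups(xs, bs=2):
    # inverse: concatenate the groups on channels and scatter back to space
    return tf.depth_to_space(tf.concat(xs, axis=3), bs)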
Example #16
def model_spec(x,
               h=None,
               init=False,
               ema=None,
               dropout_p=0.5,
               nr_resnet=5,
               nr_filters=160,
               nr_logistic_mix=10,
               resnet_nonlinearity='concat_elu',
               attention=False,
               nr_attn_block=1):
    """
    We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and produce
    a Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiber
    of the x_out tensor describes the predictive distribution for the RGB at
    that position.
    'h' is an optional N x K matrix of values to condition our generative model on
    """

    counters = {}
    with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense, nn.nin],
                   counters=counters,
                   init=init,
                   ema=ema,
                   dropout_p=dropout_p):

        # parse resnet nonlinearity argument
        if resnet_nonlinearity == 'concat_elu':
            resnet_nonlinearity = nn.concat_elu
        elif resnet_nonlinearity == 'elu':
            resnet_nonlinearity = tf.nn.elu
        elif resnet_nonlinearity == 'relu':
            resnet_nonlinearity = tf.nn.relu
        else:
            raise ValueError('resnet nonlinearity ' + resnet_nonlinearity +
                             ' is not supported')

        with arg_scope([nn.gated_resnet],
                       nonlinearity=resnet_nonlinearity,
                       h=h):

            # ////////// up pass through pixelCNN ////////
            xs = nn.int_shape(x)
            background = tf.concat([
                ((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) /
                 xs[1])[None, :, None, None] + 0. * x,
                ((tf.range(xs[2], dtype=tf.float32) - xs[2] / 2) /
                 xs[2])[None, None, :, None] + 0. * x,
            ],
                                   axis=3)
            # add channel of ones to distinguish image from padding later on
            # stream for pixels above
            x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
            u_list = [
                nn.down_shift(
                    nn.down_shifted_conv2d(x_pad,
                                           num_filters=nr_filters,
                                           filter_size=[2, 3]))
            ]
            # stream for up and to the left
            ul_list = [
                nn.down_shift(
                    nn.down_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                nn.right_shift(
                    nn.down_right_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[2, 1]))
            ]

            for attn_rep in range(nr_attn_block):
                for rep in range(nr_resnet):
                    u_list.append(
                        nn.gated_resnet(u_list[-1],
                                        conv=nn.down_shifted_conv2d))
                    ul_list.append(
                        nn.gated_resnet(ul_list[-1],
                                        u_list[-1],
                                        conv=nn.down_right_shifted_conv2d))

                if attention:
                    ul = ul_list[-1]
                    raw_content = tf.concat([x, ul, background], axis=3)
                    key, mixin = tf.split(nn.nin(
                        nn.gated_resnet(raw_content, conv=nn.nin),
                        nr_filters * 2),
                                          2,
                                          axis=3)
                    query = nn.nin(
                        nn.gated_resnet(tf.concat([ul, background], axis=3),
                                        conv=nn.nin), nr_filters)
                    mixed = nn.causal_attention(key, mixin, query)

                    ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))

            x_out = nn.nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)
            return x_out
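A minimal sketch of how such a model_spec is usually instantiated (TF 1.x), using tf.make_template so the data-dependent init pass and the training pass share variables; the placeholder shape and single init call are assumptions:

import tensorflow as tf

model = tf.make_template('model', model_spec)
x = tf.placeholder(tf.float32, [12, 32, 32, 3])

_ = model(x, init=True)        # one-time data-dependent initialization
x_out = model(x)               # (12, 32, 32, 10 * nr_logistic_mix) predictive params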