def forward_pixel_cnn_32(x, context, nr_logistic_mix=10, nr_resnet=1, nr_filters=100, dropout_p=0.0, nonlinearity=None, bn=True, kernel_initializer=None, kernel_regularizer=None, is_training=False, counters={}):
    """Build the forward (autoregressive) two-stream PixelCNN and return its output.

    The network runs an up pass and then a down pass over three resolutions:
    the input resolution and two successive stride-2 reductions. Two streams
    are carried throughout:
      * ``u``  - down-shifted convolutions ("stream for pixels above");
      * ``ul`` - down-right-shifted convolutions ("stream for up and to the left").
    Every activation pushed onto ``u_list`` / ``ul_list`` during the up pass is
    popped again, in reverse order, as a skip input during the down pass.

    Args:
        x: input tensor with a fully static shape. It is treated as rank 4
            with channels on axis 3 (presumably NHWC - TODO confirm).
        context: conditioning tensor passed as ``sh`` to the gated resnet
            blocks of the first (full-resolution) up-pass level only.
        nr_logistic_mix: number of mixture components. The output has
            ``10 * nr_logistic_mix`` channels.
        nr_resnet: number of gated resnet blocks per level on the up pass.
            The down pass uses ``nr_resnet + 1`` at the two finer levels.
        nr_filters: channel width of every shifted (de)convolution.
        dropout_p, nonlinearity: forwarded to every ``gated_resnet``.
        bn: must be falsy, because batch normalization is rejected. The
            default ``True`` is therefore rejected too, so callers must pass
            ``bn=False``.
        kernel_initializer, kernel_regularizer, is_training: forwarded to
            every gated resnet and shifted (de)convolution layer.
        counters: dict passed to ``get_name`` and to the layers. The mutable
            default is shared between calls. Presumably ``get_name`` relies
            on that to generate unique scope names, so it is left as is.

    Returns:
        The output of ``nin`` applied to ``elu(ul)``, with
        ``10 * nr_logistic_mix`` channels.

    Raises:
        ValueError: if ``bn`` is truthy.
    """
    # Validate before get_name() so that a rejected call leaves `counters`
    # untouched and prints nothing. An explicit raise (rather than `assert`)
    # keeps the check active under `python -O`.
    if bn:
        raise ValueError("autoregressive model should not use batch normalization")

    name = get_name("forward_pixel_cnn_32", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=None, sh=None, nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope([gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d, down_shifted_deconv2d, down_right_shifted_deconv2d], bn=bn, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, is_training=is_training, counters=counters):
                # ////////// up pass ////////
                # NOTE: the layer call order below determines the counter-based
                # layer names, so it must not be reordered.
                xs = int_shape(x)
                # Add a channel of ones to distinguish the image from padding later on.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

                u_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 3]))]  # stream for pixels above
                ul_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                           right_shift(down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 1]))]  # stream for up and to the left

                # Full resolution: the only level that receives `context`.
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], sh=context, conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], sh=context, conv=down_right_shifted_conv2d))

                # First stride-2 reduction.
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # Second stride-2 reduction.
                u_list.append(down_shifted_conv2d(u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, strides=[2, 2]))

                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))

                # /////// down pass ////////
                # Each level pops exactly as many skip tensors as the matching
                # up-pass level pushed (its reduction/initial output plus
                # nr_resnet blocks).
                u = u_list.pop()
                ul = ul_list.pop()

                for rep in range(nr_resnet):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()], 3), conv=down_right_shifted_conv2d)

                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])

                for rep in range(nr_resnet + 1):
                    u = gated_resnet(u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()], 3), conv=down_right_shifted_conv2d)

                u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])

                # NOTE(review): `context` is explicitly NOT fed to these
                # full-resolution down-pass blocks (sh=None), although the
                # matching up-pass blocks receive it. Confirm this asymmetry
                # is intentional.
                for rep in range(nr_resnet + 1):
                    u = gated_resnet(u, u_list.pop(), sh=None, conv=down_shifted_conv2d)
                    ul = gated_resnet(ul, tf.concat([u, ul_list.pop()], 3), sh=None, conv=down_right_shifted_conv2d)

                # Presumably 10 parameters per mixture component of a
                # discretized logistic mixture - TODO confirm against the loss.
                x_out = nin(tf.nn.elu(ul), 10 * nr_logistic_mix)
                # Internal invariant: every up-pass activation was consumed.
                assert len(u_list) == 0
                assert len(ul_list) == 0
                return x_out
    def _model(self, x, nr_resnet, nr_filters, nonlinearity, dropout_p, bn, kernel_initializer, kernel_regularizer, is_training):
        """Build a two-stream shifted-convolution PixelCNN and return a 1-channel output.

        An up pass walks three resolutions: the input resolution, then two
        stride-2 reductions. At each resolution it pushes every activation of
        both streams onto a stack:
          * the "above" stream, built with down-shifted convolutions;
          * the "above-left" stream, built with down-right-shifted convolutions.
        The down pass walks back up through the same resolutions, popping
        those activations as skip inputs. It uses ``nr_resnet`` blocks at the
        coarsest level and ``nr_resnet + 1`` at the two finer ones, which
        exactly empties both stacks.

        Args:
            x: input tensor with a fully static shape. It is treated as rank 4
                with channels on axis 3 (presumably NHWC - TODO confirm).
            nr_resnet: number of gated resnet blocks per up-pass level.
            nr_filters: channel width of every shifted (de)convolution.
            nonlinearity, dropout_p: forwarded to every ``gated_resnet``.
            bn, kernel_initializer, kernel_regularizer, is_training: forwarded
                to every gated resnet and shifted (de)convolution layer.

        Layer naming is driven by ``self.counters``, so the order in which
        layers are created below is significant.

        Returns:
            ``nin(elu(above_left), 1)``: a single-channel output map.
        """
        resnet_kwargs = dict(nonlinearity=nonlinearity, dropout_p=dropout_p, counters=self.counters)
        shared_ops = [gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d,
                      down_shifted_deconv2d, down_right_shifted_deconv2d]
        shared_kwargs = dict(bn=bn, kernel_initializer=kernel_initializer,
                             kernel_regularizer=kernel_regularizer,
                             is_training=is_training, counters=self.counters)

        with arg_scope([gated_resnet], **resnet_kwargs):
            with arg_scope(shared_ops, **shared_kwargs):
                # ---- up pass ----
                in_shape = int_shape(x)
                # Extra all-ones channel marks real pixels versus padding.
                padded = tf.concat([x, tf.ones(in_shape[:-1] + [1])], 3)

                above_stack = [down_shift(down_shifted_conv2d(
                    padded, num_filters=nr_filters, filter_size=[2, 3]))]
                left_stack = [
                    down_shift(down_shifted_conv2d(padded, num_filters=nr_filters, filter_size=[1, 3]))
                    + right_shift(down_right_shifted_conv2d(padded, num_filters=nr_filters, filter_size=[2, 1]))
                ]

                for level in range(3):
                    if level:
                        # Stride-2 reduction before each coarser level.
                        above_stack.append(down_shifted_conv2d(
                            above_stack[-1], num_filters=nr_filters, strides=[2, 2]))
                        left_stack.append(down_right_shifted_conv2d(
                            left_stack[-1], num_filters=nr_filters, strides=[2, 2]))
                    for _ in range(nr_resnet):
                        above_stack.append(gated_resnet(above_stack[-1], conv=down_shifted_conv2d))
                        left_stack.append(gated_resnet(
                            left_stack[-1], above_stack[-1], conv=down_right_shifted_conv2d))

                # ---- down pass ----
                above = above_stack.pop()
                left = left_stack.pop()
                for level, n_blocks in enumerate((nr_resnet, nr_resnet + 1, nr_resnet + 1)):
                    if level:
                        # Stride-2 upsampling back to the next finer level.
                        above = down_shifted_deconv2d(above, num_filters=nr_filters, strides=[2, 2])
                        left = down_right_shifted_deconv2d(left, num_filters=nr_filters, strides=[2, 2])
                    for _ in range(n_blocks):
                        above = gated_resnet(above, above_stack.pop(), conv=down_shifted_conv2d)
                        left = gated_resnet(
                            left, tf.concat([above, left_stack.pop()], 3),
                            conv=down_right_shifted_conv2d)

                out = nin(tf.nn.elu(left), 1)

                # Every up-pass activation must have been consumed.
                assert not above_stack
                assert not left_stack

                return out