Esempio n. 1
0
def reverse_pixel_cnn_28_binary(x,
                                masks,
                                context=None,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    """Build the ``reverse_pixel_cnn_28_binary`` sub-graph and return its features.

    The input is multiplied by ``broadcast_masks_tf(masks, num_channels=3)``
    and the single-channel broadcast mask is appended on the last axis.  A
    channel of ones is then appended on axis 3, two streams are initialised
    from ``up_*`` shifted convolutions, each stream is refined by
    ``nr_resnet`` gated-resnet blocks, and the second stream is projected
    with ``nin`` to ``nr_filters`` channels.

    Args:
        x: Input tensor.  The concatenation on axis 3 implies a 4-D,
            channels-last layout.
        masks: Forwarded to ``broadcast_masks_tf``.  NOTE(review): the mask is
            broadcast to 3 channels although the function name says
            "binary" -- confirm against the callers.
        context: Passed to every ``gated_resnet`` as ``sh``.
        nr_logistic_mix: Only printed; it does not influence the graph here.
        nr_resnet: Number of gated-resnet blocks applied to each stream.
        nr_filters: Filter count of the initial convolutions and of the
            final ``nin`` projection.
        dropout_p: Forwarded to ``gated_resnet`` through ``arg_scope``.
        nonlinearity: Forwarded to ``gated_resnet`` through ``arg_scope``.
        bn: Must be falsy.  The default ``True`` trips the assertion below,
            so callers have to pass ``bn=False`` explicitly.
        kernel_initializer: Forwarded to the scoped layers.
        kernel_regularizer: Forwarded to the scoped layers.
        is_training: Forwarded to the scoped layers.
        counters: Dict handed to ``get_name`` and to the scoped layers.  The
            default is one dict object shared by every call that omits it.

    Returns:
        The result of ``nin(tf.nn.elu(<second stream>), nr_filters)``.

    Raises:
        AssertionError: If ``bn`` is truthy.
    """
    name = get_name("reverse_pixel_cnn_28_binary", counters)

    # Apply the broadcast mask to the input, then append the mask as a channel.
    masked_input = x * broadcast_masks_tf(masks, num_channels=3)
    mask_channel = broadcast_masks_tf(masks, num_channels=1)
    x = tf.concat([masked_input, mask_channel], axis=-1)

    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"

    # Defaults injected into gated_resnet only.
    resnet_defaults = dict(gh=None,
                           sh=context,
                           nonlinearity=nonlinearity,
                           dropout_p=dropout_p)
    # Defaults injected into gated_resnet and every shifted (de)convolution.
    layer_defaults = dict(bn=bn,
                          kernel_initializer=kernel_initializer,
                          kernel_regularizer=kernel_regularizer,
                          is_training=is_training,
                          counters=counters)

    with tf.variable_scope(name):
        with arg_scope([gated_resnet], **resnet_defaults), \
                arg_scope([gated_resnet, up_shifted_conv2d,
                           up_left_shifted_conv2d, up_shifted_deconv2d,
                           up_left_shifted_deconv2d], **layer_defaults):
            x_shape = int_shape(x)
            # Extra channel of ones to distinguish image from padding later on.
            ones_channel = tf.ones(x_shape[:-1] + [1])
            x_aug = tf.concat([x, ones_channel], 3)

            # First stream: a 2x3 up-shifted convolution, shifted up once more.
            u = up_shift(
                up_shifted_conv2d(x_aug,
                                  num_filters=nr_filters,
                                  filter_size=[2, 3]))

            # Second stream: sum of a 1x3 up-shifted and a 2x1 up-left-shifted
            # convolution (built in this order so op creation order is kept).
            ul_from_rows = up_shift(
                up_shifted_conv2d(x_aug,
                                  num_filters=nr_filters,
                                  filter_size=[1, 3]))
            ul_from_cols = left_shift(
                up_left_shifted_conv2d(x_aug,
                                       num_filters=nr_filters,
                                       filter_size=[2, 1]))
            ul = ul_from_rows + ul_from_cols

            # Only the most recent activation of each stream is ever consumed,
            # so running variables replace the history lists.
            for _ in range(nr_resnet):
                u = gated_resnet(u, conv=up_shifted_conv2d)
                ul = gated_resnet(ul, u, conv=up_left_shifted_conv2d)

            return nin(tf.nn.elu(ul), nr_filters)
def cond_pixel_cnn(x,
                   gh=None,
                   sh=None,
                   nonlinearity=tf.nn.elu,
                   nr_resnet=5,
                   nr_filters=100,
                   nr_logistic_mix=10,
                   bn=False,
                   dropout_p=0.0,
                   kernel_initializer=None,
                   kernel_regularizer=None,
                   is_training=False,
                   counters={}):
    """Build a single-resolution conditional PixelCNN-style graph over ``x``.

    A channel of ones is appended to ``x`` on axis 3, two streams are
    initialised from ``down_*`` shifted convolutions, each stream is refined
    by ``nr_resnet`` gated-resnet blocks conditioned on ``gh`` / ``sh``, and
    the second stream is projected with ``nin`` to ``10 * nr_logistic_mix``
    channels.

    Args:
        x: Input tensor.  The concatenation on axis 3 implies a 4-D,
            channels-last layout.
        gh: Forwarded to every ``gated_resnet`` as ``gh``.
        sh: Forwarded to every ``gated_resnet`` as ``sh``.
        nonlinearity: Forwarded to ``gated_resnet``.
        nr_resnet: Number of gated-resnet blocks applied to each stream.
        nr_filters: Filter count of the initial convolutions.
        nr_logistic_mix: Sets the output width (``10 * nr_logistic_mix``).
        bn: Must be falsy; otherwise the assertion below fails.
        dropout_p: Forwarded to ``gated_resnet``.
        kernel_initializer: Forwarded to the scoped layers.
        kernel_regularizer: Forwarded to the scoped layers.
        is_training: Forwarded to the scoped layers.
        counters: Dict handed to ``get_name`` and, through ``arg_scope``, to
            ``gated_resnet`` only (not to the shifted convolutions).  The
            default is one dict object shared by every call that omits it.

    Returns:
        The result of ``nin(tf.nn.elu(<second stream>), 10 * nr_logistic_mix)``.

    Raises:
        AssertionError: If ``bn`` is truthy.
    """
    # NOTE(review): the scope label reads "conv_pixel_cnn" while the function
    # is called cond_pixel_cnn; the string is kept verbatim because it feeds
    # tf.variable_scope below.
    name = get_name("conv_pixel_cnn", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"

    # Defaults injected into gated_resnet only.
    resnet_defaults = dict(gh=gh,
                           sh=sh,
                           nonlinearity=nonlinearity,
                           dropout_p=dropout_p,
                           counters=counters)
    # Defaults injected into gated_resnet and the shifted convolutions.
    layer_defaults = dict(bn=bn,
                          kernel_initializer=kernel_initializer,
                          kernel_regularizer=kernel_regularizer,
                          is_training=is_training)

    with tf.variable_scope(name):
        with arg_scope([gated_resnet], **resnet_defaults), \
                arg_scope([gated_resnet, down_shifted_conv2d,
                           down_right_shifted_conv2d], **layer_defaults):
            x_shape = int_shape(x)
            # Extra channel of ones to distinguish image from padding later on.
            ones_channel = tf.ones(x_shape[:-1] + [1])
            x_aug = tf.concat([x, ones_channel], 3)

            # First stream: a 2x3 down-shifted convolution, shifted down again
            # (original comment: "stream for pixels above").
            u = down_shift(
                down_shifted_conv2d(x_aug,
                                    num_filters=nr_filters,
                                    filter_size=[2, 3]))

            # Second stream: sum of a 1x3 down-shifted and a 2x1
            # down-right-shifted convolution
            # (original comment: "stream for up and to the left").
            ul_from_rows = down_shift(
                down_shifted_conv2d(x_aug,
                                    num_filters=nr_filters,
                                    filter_size=[1, 3]))
            ul_from_cols = right_shift(
                down_right_shifted_conv2d(x_aug,
                                          num_filters=nr_filters,
                                          filter_size=[2, 1]))
            ul = ul_from_rows + ul_from_cols

            # Bookkeeping for the log line only: starts at (2, 3) and grows by
            # (1, 2) for every residual block.
            rf_rows, rf_cols = 2, 3
            for _ in range(nr_resnet):
                u = gated_resnet(u, conv=down_shifted_conv2d)
                ul = gated_resnet(ul, u, conv=down_right_shifted_conv2d)
                rf_rows += 1
                rf_cols += 2

            x_out = nin(tf.nn.elu(ul), 10 * nr_logistic_mix)
            print("    * receptive_field", (rf_rows, rf_cols))
            return x_out
def context_encoder(contexts,
                    masks,
                    is_training,
                    nr_resnet=5,
                    nr_filters=100,
                    nonlinearity=None,
                    bn=False,
                    kernel_initializer=None,
                    kernel_regularizer=None,
                    counters={}):
    """Encode masked ``contexts`` with an up/left-shifted gated-resnet stack.

    ``contexts`` is multiplied by ``broadcast_masks_tf(masks, num_channels=3)``
    and the single-channel broadcast mask is appended on the last axis.  A
    channel of ones is then appended on axis 3, two streams are initialised
    from ``up_*`` shifted convolutions, each stream is refined by
    ``nr_resnet`` gated-resnet blocks, and the second stream is projected
    with ``nin`` to ``nr_filters`` channels.

    Args:
        contexts: Input tensor.  The concatenation on axis 3 implies a 4-D,
            channels-last layout.
        masks: Forwarded to ``broadcast_masks_tf``.
        is_training: Forwarded to the scoped layers.
        nr_resnet: Number of gated-resnet blocks applied to each stream.
        nr_filters: Filter count of the initial convolutions and of the
            final ``nin`` projection.
        nonlinearity: Forwarded to ``gated_resnet``.
        bn: Forwarded to the scoped layers; a notice is printed when truthy.
        kernel_initializer: Forwarded to the scoped layers.
        kernel_regularizer: Forwarded to the scoped layers.
        counters: Dict handed to ``get_name`` and, through ``arg_scope``, to
            ``gated_resnet`` only.  The default is one dict object shared by
            every call that omits it.

    Returns:
        The result of ``nin(tf.nn.elu(<second stream>), nr_filters)``.
    """
    name = get_name("context_encoder", counters)
    print("construct", name, "...")

    # Apply the broadcast mask to the context, then append the mask as a channel.
    masked_context = contexts * broadcast_masks_tf(masks, num_channels=3)
    mask_channel = broadcast_masks_tf(masks, num_channels=1)
    x = tf.concat([masked_context, mask_channel], axis=-1)

    if bn:
        print("*** Attention *** using bn in the context encoder\n")

    # Defaults injected into gated_resnet only.
    resnet_defaults = dict(nonlinearity=nonlinearity, counters=counters)
    # Defaults injected into gated_resnet and the shifted convolutions.
    layer_defaults = dict(bn=bn,
                          kernel_initializer=kernel_initializer,
                          kernel_regularizer=kernel_regularizer,
                          is_training=is_training)

    with tf.variable_scope(name):
        with arg_scope([gated_resnet], **resnet_defaults), \
                arg_scope([gated_resnet, up_shifted_conv2d,
                           up_left_shifted_conv2d], **layer_defaults):
            x_shape = int_shape(x)
            # Extra channel of ones to distinguish image from padding later on.
            ones_channel = tf.ones(x_shape[:-1] + [1])
            x_aug = tf.concat([x, ones_channel], 3)

            # First stream: a 2x3 up-shifted convolution, shifted up once more.
            u = up_shift(
                up_shifted_conv2d(x_aug,
                                  num_filters=nr_filters,
                                  filter_size=[2, 3]))

            # Second stream: sum of a 1x3 up-shifted and a 2x1 up-left-shifted
            # convolution (built in this order so op creation order is kept).
            ul_from_rows = up_shift(
                up_shifted_conv2d(x_aug,
                                  num_filters=nr_filters,
                                  filter_size=[1, 3]))
            ul_from_cols = left_shift(
                up_left_shifted_conv2d(x_aug,
                                       num_filters=nr_filters,
                                       filter_size=[2, 1]))
            ul = ul_from_rows + ul_from_cols

            # Bookkeeping for the log line only: starts at (2, 3) and grows by
            # (1, 2) for every residual block.
            rf_rows, rf_cols = 2, 3
            for _ in range(nr_resnet):
                u = gated_resnet(u, conv=up_shifted_conv2d)
                ul = gated_resnet(ul, u, conv=up_left_shifted_conv2d)
                rf_rows += 1
                rf_cols += 2

            x_out = nin(tf.nn.elu(ul), nr_filters)
            print("    * receptive_field", (rf_rows, rf_cols))
            return x_out
def forward_pixel_cnn_32(x,
                         context,
                         nr_logistic_mix=10,
                         nr_resnet=1,
                         nr_filters=100,
                         dropout_p=0.0,
                         nonlinearity=None,
                         bn=True,
                         kernel_initializer=None,
                         kernel_regularizer=None,
                         is_training=False,
                         counters={}):
    """Build a three-stage, U-shaped PixelCNN-style graph over ``x``.

    Up pass: a channel of ones is appended to ``x`` on axis 3 and two streams
    are initialised from ``down_*`` shifted convolutions.  Each of the three
    stages then applies ``nr_resnet`` gated-resnet blocks per stream; stages
    two and three are preceded by a stride-2 shifted convolution.  Only the
    first stage passes ``context`` as ``sh``.  Every activation is pushed
    onto a stack.

    Down pass: the stacks are popped as skip inputs for ``nr_resnet``, then
    ``nr_resnet + 1``, then ``nr_resnet + 1`` gated-resnet blocks, with a
    stride-2 shifted deconvolution between consecutive groups.  The second
    stream is finally projected with ``nin`` to ``10 * nr_logistic_mix``
    channels.

    Args:
        x: Input tensor.  The concatenation on axis 3 implies a 4-D,
            channels-last layout.
        context: Passed as ``sh`` to the first-stage gated-resnet blocks.
        nr_logistic_mix: Sets the output width (``10 * nr_logistic_mix``).
        nr_resnet: Number of gated-resnet blocks per stream and stage.
        nr_filters: Filter count of every shifted (de)convolution.
        dropout_p: Forwarded to ``gated_resnet`` through ``arg_scope``.
        nonlinearity: Forwarded to ``gated_resnet`` through ``arg_scope``.
        bn: Must be falsy.  The default ``True`` trips the assertion below,
            so callers have to pass ``bn=False`` explicitly.
        kernel_initializer: Forwarded to the scoped layers.
        kernel_regularizer: Forwarded to the scoped layers.
        is_training: Forwarded to the scoped layers.
        counters: Dict handed to ``get_name`` and to the scoped layers.  The
            default is one dict object shared by every call that omits it.

    Returns:
        The result of ``nin(tf.nn.elu(<second stream>), 10 * nr_logistic_mix)``.

    Raises:
        AssertionError: If ``bn`` is truthy, or if the skip stacks are not
            empty after the down pass.
    """
    name = get_name("forward_pixel_cnn_32", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"

    # Defaults injected into gated_resnet only.
    resnet_defaults = dict(gh=None,
                           sh=None,
                           nonlinearity=nonlinearity,
                           dropout_p=dropout_p)
    # Defaults injected into gated_resnet and every shifted (de)convolution.
    layer_defaults = dict(bn=bn,
                          kernel_initializer=kernel_initializer,
                          kernel_regularizer=kernel_regularizer,
                          is_training=is_training,
                          counters=counters)

    with tf.variable_scope(name):
        with arg_scope([gated_resnet], **resnet_defaults), \
                arg_scope([gated_resnet, down_shifted_conv2d,
                           down_right_shifted_conv2d, down_shifted_deconv2d,
                           down_right_shifted_deconv2d], **layer_defaults):
            x_shape = int_shape(x)
            # Extra channel of ones to distinguish image from padding later on.
            x_aug = tf.concat([x, tf.ones(x_shape[:-1] + [1])], 3)

            # ---------------- up pass ----------------
            # First stream (original comment: "stream for pixels above").
            u_stack = [
                down_shift(
                    down_shifted_conv2d(x_aug,
                                        num_filters=nr_filters,
                                        filter_size=[2, 3]))
            ]
            # Second stream (original comment: "stream for up and to the left").
            ul_from_rows = down_shift(
                down_shifted_conv2d(x_aug,
                                    num_filters=nr_filters,
                                    filter_size=[1, 3]))
            ul_from_cols = right_shift(
                down_right_shifted_conv2d(x_aug,
                                          num_filters=nr_filters,
                                          filter_size=[2, 1]))
            ul_stack = [ul_from_rows + ul_from_cols]

            # Stage one: the only blocks that receive `context` through `sh`.
            for _ in range(nr_resnet):
                u_stack.append(
                    gated_resnet(u_stack[-1],
                                 sh=context,
                                 conv=down_shifted_conv2d))
                ul_stack.append(
                    gated_resnet(ul_stack[-1],
                                 u_stack[-1],
                                 sh=context,
                                 conv=down_right_shifted_conv2d))

            # Stages two and three: stride-2 convolution, then blocks that
            # fall back to the scoped sh=None.
            for _stage in range(2):
                u_stack.append(
                    down_shifted_conv2d(u_stack[-1],
                                        num_filters=nr_filters,
                                        strides=[2, 2]))
                ul_stack.append(
                    down_right_shifted_conv2d(ul_stack[-1],
                                              num_filters=nr_filters,
                                              strides=[2, 2]))
                for _ in range(nr_resnet):
                    u_stack.append(
                        gated_resnet(u_stack[-1], conv=down_shifted_conv2d))
                    ul_stack.append(
                        gated_resnet(ul_stack[-1],
                                     u_stack[-1],
                                     conv=down_right_shifted_conv2d))

            # ---------------- down pass ----------------
            u = u_stack.pop()
            ul = ul_stack.pop()

            # First two groups, each followed by a stride-2 deconvolution.
            for n_blocks in (nr_resnet, nr_resnet + 1):
                for _ in range(n_blocks):
                    u = gated_resnet(u,
                                     u_stack.pop(),
                                     conv=down_shifted_conv2d)
                    skip = tf.concat([u, ul_stack.pop()], 3)
                    ul = gated_resnet(ul,
                                      skip,
                                      conv=down_right_shifted_conv2d)
                u = down_shifted_deconv2d(u,
                                          num_filters=nr_filters,
                                          strides=[2, 2])
                ul = down_right_shifted_deconv2d(ul,
                                                 num_filters=nr_filters,
                                                 strides=[2, 2])

            # Final group; `sh=None` is passed explicitly, as in the original.
            for _ in range(nr_resnet + 1):
                u = gated_resnet(u,
                                 u_stack.pop(),
                                 sh=None,
                                 conv=down_shifted_conv2d)
                skip = tf.concat([u, ul_stack.pop()], 3)
                ul = gated_resnet(ul,
                                  skip,
                                  sh=None,
                                  conv=down_right_shifted_conv2d)

            x_out = nin(tf.nn.elu(ul), 10 * nr_logistic_mix)
            # Every pushed activation must have been consumed as a skip input.
            assert len(u_stack) == 0
            assert len(ul_stack) == 0
            return x_out
Esempio n. 5
0
    def _model(self, x, nr_resnet, nr_filters, nonlinearity, dropout_p, bn, kernel_initializer, kernel_regularizer, is_training):
        """Build a three-stage, U-shaped PixelCNN-style graph with one output channel.

        Up pass: a channel of ones is appended to ``x`` on axis 3 and two
        streams are initialised from ``down_*`` shifted convolutions.  Each of
        the three stages applies ``nr_resnet`` gated-resnet blocks per stream;
        stages two and three are preceded by a stride-2 shifted convolution.
        Every activation is pushed onto a stack.

        Down pass: the stacks are popped as skip inputs for ``nr_resnet``,
        then ``nr_resnet + 1``, then ``nr_resnet + 1`` gated-resnet blocks,
        with a stride-2 shifted deconvolution before the second and third
        group.  The second stream is finally projected with ``nin`` to a
        single channel.

        Args:
            x: Input tensor.  The concatenation on axis 3 implies a 4-D,
                channels-last layout.
            nr_resnet: Number of gated-resnet blocks per stream and stage.
            nr_filters: Filter count of every shifted (de)convolution.
            nonlinearity: Forwarded to ``gated_resnet`` through ``arg_scope``.
            dropout_p: Forwarded to ``gated_resnet`` through ``arg_scope``.
            bn: Forwarded to the scoped layers.
            kernel_initializer: Forwarded to the scoped layers.
            kernel_regularizer: Forwarded to the scoped layers.
            is_training: Forwarded to the scoped layers.

        Returns:
            The result of ``nin(tf.nn.elu(<second stream>), 1)``.

        Raises:
            AssertionError: If the skip stacks are not empty after the down
                pass.
        """
        # Defaults injected into gated_resnet only.
        resnet_defaults = dict(nonlinearity=nonlinearity,
                               dropout_p=dropout_p,
                               counters=self.counters)
        # Defaults injected into gated_resnet and every shifted (de)convolution.
        layer_defaults = dict(bn=bn,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=kernel_regularizer,
                              is_training=is_training,
                              counters=self.counters)

        with arg_scope([gated_resnet], **resnet_defaults), \
                arg_scope([gated_resnet, down_shifted_conv2d,
                           down_right_shifted_conv2d, down_shifted_deconv2d,
                           down_right_shifted_deconv2d], **layer_defaults):

            # ---------------- up pass ----------------
            x_shape = int_shape(x)
            # Extra channel of ones appended on the channel axis.
            x_aug = tf.concat([x, tf.ones(x_shape[:-1] + [1])], 3)

            # First stream (original comment: "stream for pixels above").
            u_stack = [
                down_shift(
                    down_shifted_conv2d(x_aug,
                                        num_filters=nr_filters,
                                        filter_size=[2, 3]))
            ]
            # Second stream (original comment: "stream for up and to the left").
            ul_from_rows = down_shift(
                down_shifted_conv2d(x_aug,
                                    num_filters=nr_filters,
                                    filter_size=[1, 3]))
            ul_from_cols = right_shift(
                down_right_shifted_conv2d(x_aug,
                                          num_filters=nr_filters,
                                          filter_size=[2, 1]))
            ul_stack = [ul_from_rows + ul_from_cols]

            for stage in range(3):
                if stage > 0:
                    # Stride-2 convolution before stages two and three.
                    u_stack.append(
                        down_shifted_conv2d(u_stack[-1],
                                            num_filters=nr_filters,
                                            strides=[2, 2]))
                    ul_stack.append(
                        down_right_shifted_conv2d(ul_stack[-1],
                                                  num_filters=nr_filters,
                                                  strides=[2, 2]))
                for _ in range(nr_resnet):
                    u_stack.append(
                        gated_resnet(u_stack[-1], conv=down_shifted_conv2d))
                    ul_stack.append(
                        gated_resnet(ul_stack[-1],
                                     u_stack[-1],
                                     conv=down_right_shifted_conv2d))

            # ---------------- down pass ----------------
            u = u_stack.pop()
            ul = ul_stack.pop()

            # Block counts per group; together with the two pops above they
            # consume exactly what the up pass pushed.
            blocks_per_group = (nr_resnet, nr_resnet + 1, nr_resnet + 1)
            for group, n_blocks in enumerate(blocks_per_group):
                if group > 0:
                    # Stride-2 deconvolution before groups two and three.
                    u = down_shifted_deconv2d(u,
                                              num_filters=nr_filters,
                                              strides=[2, 2])
                    ul = down_right_shifted_deconv2d(ul,
                                                     num_filters=nr_filters,
                                                     strides=[2, 2])
                for _ in range(n_blocks):
                    u = gated_resnet(u,
                                     u_stack.pop(),
                                     conv=down_shifted_conv2d)
                    skip = tf.concat([u, ul_stack.pop()], 3)
                    ul = gated_resnet(ul,
                                      skip,
                                      conv=down_right_shifted_conv2d)

            x_out = nin(tf.nn.elu(ul), 1)

            # Every pushed activation must have been consumed as a skip input.
            assert len(u_stack) == 0
            assert len(ul_stack) == 0

            return x_out