Esempio n. 1
0
def forward_pixel_cnn_28_binary(x,
                                context,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=False,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    """Build a single-resolution, context-conditioned PixelCNN stack.

    Two causal feature streams are derived from ``x`` (one for pixels above,
    one for pixels above and to the left), each refined by ``nr_resnet``
    gated residual blocks, and the final up-left stream is projected by
    ``nin(..., 1)``.

    Args:
        x: input tensor; a channel of ones is concatenated along axis 3
            (presumably NHWC, e.g. 28x28 binary images per the function
            name -- TODO confirm).
        context: passed as ``sh`` to every ``gated_resnet`` (presumably a
            conditioning tensor -- TODO confirm).
        nr_logistic_mix: only printed; not used to build any layer here.
        nr_resnet: number of gated residual blocks in each stream.
        nr_filters: channel count of every shifted convolution built here.
        dropout_p: forwarded to every ``gated_resnet``.
        nonlinearity: forwarded to every ``gated_resnet``.
        bn: must be falsy; batch normalization is rejected for this
            auto-regressive model. The default is False because it is the
            only accepted value.
        kernel_initializer: forwarded to every gated_resnet / shifted
            (de)convolution via ``arg_scope``.
        kernel_regularizer: forwarded like ``kernel_initializer``.
        is_training: forwarded like ``kernel_initializer``.
        counters: dict used by ``get_name`` and passed to every layer for
            naming. NOTE(review): the default dict is shared across calls
            (mutable default); presumably relied on to keep generated names
            unique between calls, so it is intentionally left as-is --
            confirm before changing.

    Returns:
        The result of ``nin(tf.nn.elu(ul), 1)`` on the last up-left stream.

    Raises:
        AssertionError: if ``bn`` is truthy. Raised before any name is taken
            from ``counters`` and before anything is printed.
    """
    # Validate first: a rejected call should neither consume a name from
    # `counters` via get_name() nor print a construction banner.
    assert not bn, "auto-reggressive model should not use batch normalization"
    name = get_name("forward_pixel_cnn_28_binary", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=None,
                       sh=context,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p):
            with arg_scope([
                    gated_resnet, down_shifted_conv2d,
                    down_right_shifted_conv2d, down_shifted_deconv2d,
                    down_right_shifted_deconv2d
            ],
                           bn=bn,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training,
                           counters=counters):
                xs = int_shape(x)
                # Add a channel of ones so later layers can distinguish image
                # pixels from zero padding.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

                # Stream for pixels above.
                u_list = [
                    down_shift(
                        down_shifted_conv2d(x_pad,
                                            num_filters=nr_filters,
                                            filter_size=[2, 3]))
                ]
                # Stream for pixels above and to the left.
                ul_list = [
                    down_shift(
                        down_shifted_conv2d(x_pad,
                                            num_filters=nr_filters,
                                            filter_size=[1, 3])) +
                    right_shift(
                        down_right_shifted_conv2d(x_pad,
                                                  num_filters=nr_filters,
                                                  filter_size=[2, 1]))
                ]

                for _ in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    # The up-left block receives the newest "above" output
                    # as its second input.
                    ul_list.append(
                        gated_resnet(ul_list[-1],
                                     u_list[-1],
                                     conv=down_right_shifted_conv2d))

                # Presumably a 1x1 (network-in-network) projection to a single
                # output channel -- TODO confirm against nin().
                return nin(tf.nn.elu(ul_list[-1]), 1)
Esempio n. 2
0
    def _model(self, x, nr_resnet, nr_filters, nonlinearity, dropout_p, bn, kernel_initializer, kernel_regularizer, is_training):
        """Build a three-resolution PixelCNN-style encoder/decoder over ``x``.

        Two causal streams are maintained: "above" (vertical stack) and
        "above_left" (horizontal stack). The up pass runs ``nr_resnet`` gated
        residual blocks at each of three resolutions, separated by two
        ``strides=[2, 2]`` shifted convolutions, and records every activation.
        The down pass mirrors it with ``strides=[2, 2]`` shifted
        deconvolutions and consumes every recorded activation as a skip
        input: ``nr_resnet`` blocks at the first (coarsest) level, then
        ``nr_resnet + 1`` blocks at each of the two following levels.

        Args:
            x: input tensor; a ones channel is concatenated along axis 3
                (presumably NHWC -- TODO confirm).
            nr_resnet: gated residual blocks per resolution in the up pass.
            nr_filters: channel count for every convolution built here.
            nonlinearity, dropout_p: forwarded to every ``gated_resnet``.
            bn, kernel_initializer, kernel_regularizer, is_training: forwarded
                to every gated_resnet and shifted (de)convolution.

        Returns:
            ``nin(tf.nn.elu(ul), 1)`` applied to the final up-left stream.
        """
        # NOTE: layer names come from self.counters, so the order in which
        # layers are constructed below must not change.
        shifted_layers = [gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d,
                          down_shifted_deconv2d, down_right_shifted_deconv2d]
        num_scales = 3

        with arg_scope([gated_resnet], nonlinearity=nonlinearity, dropout_p=dropout_p, counters=self.counters):
            with arg_scope(shifted_layers, bn=bn, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer, is_training=is_training,
                           counters=self.counters):
                # Append a channel of ones so padding can be told apart from
                # real image content.
                in_shape = int_shape(x)
                padded = tf.concat([x, tf.ones(in_shape[:-1] + [1])], 3)

                # ---- up pass ----
                above = [down_shift(down_shifted_conv2d(
                    padded, num_filters=nr_filters, filter_size=[2, 3]))]
                from_above = down_shift(down_shifted_conv2d(
                    padded, num_filters=nr_filters, filter_size=[1, 3]))
                from_left = right_shift(down_right_shifted_conv2d(
                    padded, num_filters=nr_filters, filter_size=[2, 1]))
                above_left = [from_above + from_left]

                for scale in range(num_scales):
                    if scale > 0:
                        # Stride-2 transition between resolutions.
                        above.append(down_shifted_conv2d(
                            above[-1], num_filters=nr_filters, strides=[2, 2]))
                        above_left.append(down_right_shifted_conv2d(
                            above_left[-1], num_filters=nr_filters, strides=[2, 2]))
                    for _ in range(nr_resnet):
                        above.append(gated_resnet(above[-1], conv=down_shifted_conv2d))
                        above_left.append(gated_resnet(
                            above_left[-1], above[-1], conv=down_right_shifted_conv2d))

                # ---- down pass ----
                u = above.pop()
                ul = above_left.pop()
                for scale in range(num_scales):
                    if scale > 0:
                        # Stride-2 deconvolution back toward the finer resolution.
                        u = down_shifted_deconv2d(u, num_filters=nr_filters, strides=[2, 2])
                        ul = down_right_shifted_deconv2d(ul, num_filters=nr_filters, strides=[2, 2])
                    # One extra block per finer level consumes the transition
                    # activation stored during the up pass.
                    block_count = nr_resnet + 1 if scale > 0 else nr_resnet
                    for _ in range(block_count):
                        u = gated_resnet(u, above.pop(), conv=down_shifted_conv2d)
                        skip = tf.concat([u, above_left.pop()], 3)
                        ul = gated_resnet(ul, skip, conv=down_right_shifted_conv2d)

                x_out = nin(tf.nn.elu(ul), 1)

                # Every up-pass activation must have been consumed exactly once.
                assert not above
                assert not above_left

                return x_out