Example #1
0
def reverse_pixel_cnn_28_binary(x,
                                masks,
                                context=None,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    """Build a masked up / up-left gated-resnet stack and return its features.

    ``x`` is multiplied by a 3-channel broadcast of ``masks``. A 1-channel
    broadcast of ``masks`` is then appended, followed by a channel of ones
    (to distinguish image from padding). Two streams are built from the
    result:

    * a "pixels above" stream from ``up_shifted_conv2d``;
    * an "up and to the left" stream that sums an up-shifted and a
      left-shifted conv.

    Each of the ``nr_resnet`` rounds adds one ``gated_resnet`` per stream.
    The up-left stream is conditioned on the freshly updated up stream.

    Args:
        x: input tensor. It is multiplied by a 3-channel mask broadcast, so
            it presumably has 3 channels (TODO confirm).
        masks: mask passed to ``broadcast_masks_tf``.
        context: forwarded as ``sh`` to every ``gated_resnet``.
        nr_logistic_mix: only printed; not used to build any layer here.
        nr_resnet: number of gated-resnet rounds per stream.
        nr_filters: filter count of every conv and of the final ``nin``.
        dropout_p: forwarded to every ``gated_resnet``.
        nonlinearity: forwarded to every ``gated_resnet``.
        bn: must be falsy. NOTE: the default ``True`` is always rejected, so
            callers have to pass ``bn=False`` explicitly.
        kernel_initializer: forwarded via ``arg_scope`` to the gated resnets
            and the shifted conv/deconv layers.
        kernel_regularizer: forwarded the same way as ``kernel_initializer``.
        is_training: forwarded the same way as ``kernel_initializer``.
        counters: dict passed to ``get_name`` and, via ``arg_scope``, to the
            layers. The mutable default is shared between calls that omit
            it; presumably this is what keeps generated names unique across
            calls (confirm against ``get_name``).

    Returns:
        ``nin(elu(up_left_stream), nr_filters)``.

    Raises:
        ValueError: if ``bn`` is truthy.
    """
    # Reject batch norm with an explicit raise: an `assert` would be stripped
    # under `python -O`. The check runs before get_name is called or any op is
    # built, so a rejected call leaves no side effects behind.
    if bn:
        raise ValueError(
            "auto-regressive model should not use batch normalization")

    name = get_name("reverse_pixel_cnn_28_binary", counters)
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       gh=None,
                       sh=context,
                       nonlinearity=nonlinearity,
                       dropout_p=dropout_p):
            with arg_scope([
                    gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d,
                    up_shifted_deconv2d, up_left_shifted_deconv2d
            ],
                           bn=bn,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training,
                           counters=counters):
                # NOTE: keep the layer call order below unchanged. Layer names
                # are presumably derived from call order via `counters`, so
                # reordering could rename variables.
                xs = int_shape(x)
                # Add a channel of ones to distinguish image from padding.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

                # Stream for pixels above.
                u = up_shift(
                    up_shifted_conv2d(x_pad,
                                      num_filters=nr_filters,
                                      filter_size=[2, 3]))
                # Stream for pixels above and to the left.
                ul = up_shift(
                    up_shifted_conv2d(x_pad,
                                      num_filters=nr_filters,
                                      filter_size=[1, 3])) + left_shift(
                        up_left_shifted_conv2d(x_pad,
                                               num_filters=nr_filters,
                                               filter_size=[2, 1]))

                for _ in range(nr_resnet):
                    u = gated_resnet(u, conv=up_shifted_conv2d)
                    # The up-left stream is conditioned on the updated up stream.
                    ul = gated_resnet(ul, u, conv=up_left_shifted_conv2d)

                return nin(tf.nn.elu(ul), nr_filters)
def context_encoder(contexts,
                    masks,
                    is_training,
                    nr_resnet=5,
                    nr_filters=100,
                    nonlinearity=None,
                    bn=False,
                    kernel_initializer=None,
                    kernel_regularizer=None,
                    counters={}):
    """Encode masked ``contexts`` with stacked up / up-left gated resnets.

    ``contexts`` is multiplied by a 3-channel broadcast of ``masks``. A
    1-channel broadcast of ``masks`` is then appended, followed by a channel
    of ones (to tell the image apart from padding). Two streams are built
    from that tensor:

    * one from an up-shifted conv;
    * one summing an up-shifted and a left-shifted conv.

    Each stream then gets ``nr_resnet`` rounds of ``gated_resnet``. The
    up-left stream is conditioned on the freshly updated up stream.

    Args:
        contexts: input tensor. It is multiplied by a 3-channel mask
            broadcast, so it presumably has 3 channels (TODO confirm).
        masks: mask fed to ``broadcast_masks_tf``.
        is_training: forwarded via ``arg_scope`` to the gated resnets and
            the shifted convs.
        nr_resnet: number of gated-resnet rounds per stream.
        nr_filters: filter count of every conv and of the final ``nin``.
        nonlinearity: forwarded to every ``gated_resnet``.
        bn: forwarded via ``arg_scope``. A notice is printed when truthy.
        kernel_initializer: forwarded via ``arg_scope``.
        kernel_regularizer: forwarded via ``arg_scope``.
        counters: dict passed to ``get_name`` and, via ``arg_scope``, to
            ``gated_resnet`` only. The mutable default dict is shared
            between calls that omit it.

    Returns:
        ``nin(elu(up_left_stream), nr_filters)``.

    The function also prints a bookkeeping ``receptive_field`` tuple. It
    starts at (2, 3) and grows by (1, 2) per resnet round.
    """
    name = get_name("context_encoder", counters)
    print("construct", name, "...")

    mask3 = broadcast_masks_tf(masks, num_channels=3)
    masked = contexts * mask3
    mask1 = broadcast_masks_tf(masks, num_channels=1)
    x = tf.concat([masked, mask1], axis=-1)

    if bn:
        print("*** Attention *** using bn in the context encoder\n")

    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       nonlinearity=nonlinearity,
                       counters=counters):
            # NOTE(review): unlike reverse_pixel_cnn_28_binary, `counters` is
            # not forwarded to the shifted convs here. Presumably they fall
            # back to their own default; confirm before changing, because
            # layer names may be tied to saved checkpoints.
            with arg_scope(
                [gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d],
                    bn=bn,
                    kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                # Keep the layer call order below unchanged: names are
                # presumably derived from call order.
                leading_dims = int_shape(x)[:-1]
                # Extra all-ones channel distinguishes image from padding.
                ones_channel = tf.ones(leading_dims + [1])
                x_pad = tf.concat([x, ones_channel], 3)

                # Stream for pixels above.
                up_stream = up_shift(
                    up_shifted_conv2d(x_pad,
                                      num_filters=nr_filters,
                                      filter_size=[2, 3]))

                # Stream for pixels above and to the left.
                from_above = up_shift(
                    up_shifted_conv2d(x_pad,
                                      num_filters=nr_filters,
                                      filter_size=[1, 3]))
                from_left = left_shift(
                    up_left_shifted_conv2d(x_pad,
                                           num_filters=nr_filters,
                                           filter_size=[2, 1]))
                up_left_stream = from_above + from_left

                rows, cols = 2, 3
                for _ in range(nr_resnet):
                    up_stream = gated_resnet(up_stream, conv=up_shifted_conv2d)
                    up_left_stream = gated_resnet(up_left_stream,
                                                  up_stream,
                                                  conv=up_left_shifted_conv2d)
                    rows, cols = rows + 1, cols + 2

                encoded = nin(tf.nn.elu(up_left_stream), nr_filters)
                print("    * receptive_field", (rows, cols))
                return encoded