def reverse_pixel_cnn_28_binary(x, masks, context=None, nr_logistic_mix=10,
                                nr_resnet=1, nr_filters=100, dropout_p=0.0,
                                nonlinearity=None, bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None, is_training=False,
                                counters={}):
    """Build the reverse PixelCNN tower for 28x28 binary images.

    The input is masked (masked-out pixels zeroed) and the mask itself is
    appended as an extra channel so the network can distinguish "masked"
    from "genuinely zero".  Two shifted-convolution streams are built
    (`u_list` for pixels above, `ul_list` for pixels above-and-to-the-left,
    per the original comments) and refined by `nr_resnet` gated residual
    blocks, ending in a 1x1 `nin` projection.

    Args:
        x: input image tensor.  # assumes NHWC with 3 channels — TODO confirm
        masks: binary mask tensor broadcast over the image.
        context: optional conditioning tensor fed to gated_resnet as `sh`.
        nr_logistic_mix: only printed here; not otherwise used in this block.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channel width of every internal layer and of the output.
        dropout_p: dropout rate passed to gated_resnet.
        nonlinearity: activation passed to gated_resnet.
        bn: must be falsy — batch norm breaks autoregressive masking.
        kernel_initializer / kernel_regularizer: passed to all conv layers.
        is_training: training-mode flag passed to all conv layers.
        counters: name-counter dict for get_name.
            NOTE(review): the shared mutable default appears intentional —
            get_name mutates it so repeated constructions get distinct
            variable-scope names; pass an explicit dict to opt out.

    Returns:
        An nr_filters-channel feature map tensor.

    Raises:
        ValueError: if `bn` is truthy (was an assert, which `-O` strips).
    """
    name = get_name("reverse_pixel_cnn_28_binary", counters)
    # Zero masked pixels, then append the mask as a 4th input channel.
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    # Explicit raise instead of assert: input validation must survive -O.
    # (Also fixes the "reggressive" typo in the original message.)
    if bn:
        raise ValueError(
            "auto-regressive model should not use batch normalization")
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=None, sh=context,
                       nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope([gated_resnet, up_shifted_conv2d,
                            up_left_shifted_conv2d, up_shifted_deconv2d,
                            up_left_shifted_deconv2d],
                           bn=bn, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training, counters=counters):
                xs = int_shape(x)
                # Add channel of ones to distinguish image from padding
                # later on.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
                # Stream for pixels above.
                u_list = [up_shift(up_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
                # Stream for up and to the left.
                ul_list = [
                    up_shift(up_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[1, 3]))
                    + left_shift(up_left_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[2, 1]))
                ]
                for _ in range(nr_resnet):
                    u_list.append(
                        gated_resnet(u_list[-1], conv=up_shifted_conv2d))
                    ul_list.append(
                        gated_resnet(ul_list[-1], u_list[-1],
                                     conv=up_left_shifted_conv2d))
                # 1x1 projection down to nr_filters output channels.
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                return x_out
def context_encoder(contexts, masks, is_training, nr_resnet=5, nr_filters=100,
                    nonlinearity=None, bn=False, kernel_initializer=None,
                    kernel_regularizer=None, counters={}):
    """Encode masked context images into an nr_filters-channel feature map.

    Masked-out pixels are zeroed and the mask is appended as an extra
    channel; two shifted-convolution streams ("above" and
    "above-and-left", per the original comments) are refined by
    `nr_resnet` gated residual blocks, then projected by a 1x1 `nin`
    layer.  The growing receptive field is tracked and printed.
    """
    name = get_name("context_encoder", counters)
    print("construct", name, "...")
    # Zero the masked pixels, then concatenate the mask as a 4th channel
    # so the encoder can tell "masked" apart from "actually zero".
    masked = contexts * broadcast_masks_tf(masks, num_channels=3)
    inputs = tf.concat(
        [masked, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    if bn:
        print("*** Attention *** using bn in the context encoder\n")
    with tf.variable_scope(name):
        with arg_scope([gated_resnet],
                       nonlinearity=nonlinearity, counters=counters):
            with arg_scope(
                    [gated_resnet, up_shifted_conv2d, up_left_shifted_conv2d],
                    bn=bn, kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    is_training=is_training):
                shape = int_shape(inputs)
                # Channel of ones distinguishes image from padding later on.
                padded = tf.concat([inputs, tf.ones(shape[:-1] + [1])], 3)
                # Stream for pixels above.
                above = [
                    up_shift(up_shifted_conv2d(
                        padded, num_filters=nr_filters, filter_size=[2, 3]))
                ]
                # Stream for up and to the left.
                above_left = [
                    up_shift(up_shifted_conv2d(
                        padded, num_filters=nr_filters, filter_size=[1, 3]))
                    + left_shift(up_left_shifted_conv2d(
                        padded, num_filters=nr_filters, filter_size=[2, 1]))
                ]
                receptive_field = (2, 3)
                for _ in range(nr_resnet):
                    above.append(
                        gated_resnet(above[-1], conv=up_shifted_conv2d))
                    above_left.append(
                        gated_resnet(above_left[-1], above[-1],
                                     conv=up_left_shifted_conv2d))
                    # Each block extends the field by 1 row and 2 columns.
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                x_out = nin(tf.nn.elu(above_left[-1]), nr_filters)
                print(" * receptive_field", receptive_field)
                return x_out