def reverse_pixel_cnn_28_binary(x, masks, context=None, nr_logistic_mix=10,
                                nr_resnet=1, nr_filters=100, dropout_p=0.0,
                                nonlinearity=None, bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None, is_training=False,
                                counters=None):
    """Build a reverse-direction PixelCNN over a masked binary input.

    The input is zeroed at unobserved pixels (via `masks`) and the 1-channel
    mask is concatenated so downstream layers can distinguish observed from
    masked pixels. Two causal streams (`up_shifted` / `up_left_shifted`
    convolutions) are run through `nr_resnet` gated residual blocks and the
    result is projected with a 1x1 `nin` layer.

    Args:
        x: input image tensor; presumably NHWC with 3 channels -- TODO confirm.
        masks: per-pixel visibility mask, broadcast over channels.
        context: optional conditioning tensor fed to gated_resnet as `sh`.
        nr_logistic_mix: only printed for logging; not used by this builder.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channel width of internal convolutions and of the output.
        dropout_p: dropout probability inside gated_resnet.
        nonlinearity: activation passed to gated_resnet.
        bn: must be False; batch norm breaks autoregressive conditioning.
        kernel_initializer: initializer forwarded to the conv layers.
        kernel_regularizer: regularizer forwarded to the conv layers.
        is_training: training-phase flag forwarded to the layers.
        counters: name-uniquing dict for `get_name`. A fresh dict is created
            per call when omitted (the old `counters={}` default was a shared
            mutable object across calls).

    Returns:
        Feature tensor with `nr_filters` channels.
    """
    if counters is None:
        counters = {}
    name = get_name("reverse_pixel_cnn_28_binary", counters)
    # Zero out unobserved pixels and append the mask as an extra channel.
    x = x * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-regressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=None, sh=context,
                       nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope([gated_resnet, up_shifted_conv2d,
                            up_left_shifted_conv2d, up_shifted_deconv2d,
                            up_left_shifted_deconv2d],
                           bn=bn, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training, counters=counters):
                xs = int_shape(x)
                # Channel of ones distinguishes real pixels from padding.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
                # First causal stream (up-shifted convolutions).
                # NOTE(review): original comment said "pixels above", copied
                # from the down-shifted variant; with up_* ops the visible
                # region is presumably mirrored -- confirm.
                u_list = [up_shift(up_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
                # Second causal stream: up-shifted plus left-shifted parts.
                ul_list = [up_shift(up_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                    left_shift(up_left_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[2, 1]))]
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(
                        u_list[-1], conv=up_shifted_conv2d))
                    # ul stream also sees the u stream's activations.
                    ul_list.append(gated_resnet(
                        ul_list[-1], u_list[-1],
                        conv=up_left_shifted_conv2d))
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                return x_out
def cond_pixel_cnn(x, gh=None, sh=None, nonlinearity=tf.nn.elu, nr_resnet=5,
                   nr_filters=100, nr_logistic_mix=10, bn=False,
                   dropout_p=0.0, kernel_initializer=None,
                   kernel_regularizer=None, is_training=False, counters=None):
    """Build a conditional PixelCNN head producing logistic-mixture params.

    Runs the standard two-stream causal stack (down-shifted and
    down-right-shifted convolutions) for `nr_resnet` gated residual blocks,
    then projects to `10 * nr_logistic_mix` output channels with a 1x1
    `nin` layer. Tracks and prints the receptive field grown by the stack.

    Args:
        x: input image tensor; presumably NHWC -- TODO confirm.
        gh: global conditioning tensor forwarded to gated_resnet.
        sh: spatial conditioning tensor forwarded to gated_resnet.
        nonlinearity: activation for gated_resnet (default: tf.nn.elu).
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channel width of the internal convolutions.
        nr_logistic_mix: number of logistic mixture components; output has
            10 * nr_logistic_mix channels.
        bn: must be False; batch norm breaks autoregressive conditioning.
        dropout_p: dropout probability inside gated_resnet.
        kernel_initializer: initializer forwarded to the conv layers.
        kernel_regularizer: regularizer forwarded to the conv layers.
        is_training: training-phase flag forwarded to the layers.
        counters: name-uniquing dict for `get_name`. A fresh dict is created
            per call when omitted (the old `counters={}` default was a shared
            mutable object across calls).

    Returns:
        Tensor with `10 * nr_logistic_mix` channels.
    """
    if counters is None:
        counters = {}
    # NOTE(review): scope name is "conv_pixel_cnn", not "cond_pixel_cnn".
    # Kept byte-identical since checkpoints may depend on it -- confirm.
    name = get_name("conv_pixel_cnn", counters)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-regressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=gh, sh=sh,
                       nonlinearity=nonlinearity, dropout_p=dropout_p,
                       counters=counters):
            # NOTE(review): unlike the other builders in this file, the conv
            # ops here are not given `counters` in this scope -- confirm
            # whether that is intended.
            with arg_scope([gated_resnet, down_shifted_conv2d,
                            down_right_shifted_conv2d],
                           bn=bn, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training):
                xs = int_shape(x)
                # Channel of ones distinguishes real pixels from padding.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
                # Stream for pixels above the current one.
                u_list = [down_shift(down_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
                # Stream for pixels above and to the left.
                ul_list = [down_shift(down_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                    right_shift(down_right_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[2, 1]))]
                receptive_field = (2, 3)
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(
                        u_list[-1], conv=down_shifted_conv2d))
                    # ul stream also sees the u stream's activations.
                    ul_list.append(gated_resnet(
                        ul_list[-1], u_list[-1],
                        conv=down_right_shifted_conv2d))
                    # Each block grows the field by 1 row and 2 columns.
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                x_out = nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)
                print(" * receptive_field", receptive_field)
                return x_out
def context_encoder(contexts, masks, is_training, nr_resnet=5, nr_filters=100,
                    nonlinearity=None, bn=False, kernel_initializer=None,
                    kernel_regularizer=None, counters=None):
    """Encode masked context pixels into an nr_filters-channel feature map.

    Masks the input, appends the mask as a channel, and runs the
    reverse-direction two-stream stack (`up_shifted` / `up_left_shifted`
    convolutions) through `nr_resnet` gated residual blocks, ending with a
    1x1 `nin` projection. Unlike the autoregressive builders in this file,
    batch normalization is permitted here (a warning is printed when on).

    Args:
        contexts: context image tensor; presumably NHWC with 3 channels --
            TODO confirm.
        masks: per-pixel visibility mask, broadcast over channels.
        is_training: training-phase flag forwarded to the layers.
        nr_resnet: number of gated residual blocks per stream.
        nr_filters: channel width of internal convolutions and of the output.
        nonlinearity: activation passed to gated_resnet.
        bn: whether to use batch normalization in this encoder.
        kernel_initializer: initializer forwarded to the conv layers.
        kernel_regularizer: regularizer forwarded to the conv layers.
        counters: name-uniquing dict for `get_name`. A fresh dict is created
            per call when omitted (the old `counters={}` default was a shared
            mutable object across calls).

    Returns:
        Feature tensor with `nr_filters` channels.
    """
    if counters is None:
        counters = {}
    name = get_name("context_encoder", counters)
    print("construct", name, "...")
    # Zero out unobserved pixels and append the mask as an extra channel.
    x = contexts * broadcast_masks_tf(masks, num_channels=3)
    x = tf.concat([x, broadcast_masks_tf(masks, num_channels=1)], axis=-1)
    if bn:
        print("*** Attention *** using bn in the context encoder\n")
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], nonlinearity=nonlinearity,
                       counters=counters):
            with arg_scope([gated_resnet, up_shifted_conv2d,
                            up_left_shifted_conv2d],
                           bn=bn, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training):
                xs = int_shape(x)
                # Channel of ones distinguishes real pixels from padding.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
                # First causal stream (up-shifted convolutions).
                u_list = [up_shift(up_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
                # Second causal stream: up-shifted plus left-shifted parts.
                ul_list = [up_shift(up_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                    left_shift(up_left_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[2, 1]))]
                receptive_field = (2, 3)
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(
                        u_list[-1], conv=up_shifted_conv2d))
                    # ul stream also sees the u stream's activations.
                    ul_list.append(gated_resnet(
                        ul_list[-1], u_list[-1],
                        conv=up_left_shifted_conv2d))
                    # Each block grows the field by 1 row and 2 columns.
                    receptive_field = (receptive_field[0] + 1,
                                       receptive_field[1] + 2)
                x_out = nin(tf.nn.elu(ul_list[-1]), nr_filters)
                print(" * receptive_field", receptive_field)
                return x_out
def forward_pixel_cnn_32(x, context, nr_logistic_mix=10, nr_resnet=1,
                         nr_filters=100, dropout_p=0.0, nonlinearity=None,
                         bn=True, kernel_initializer=None,
                         kernel_regularizer=None, is_training=False,
                         counters=None):
    """Build a U-net-shaped forward PixelCNN conditioned on `context`.

    Up pass: three stages of `nr_resnet` gated residual blocks on the two
    causal streams, with stride-2 downsampling between stages; `context` is
    injected (as `sh`) only in the first, full-resolution stage. Down pass:
    mirrors the up pass with stride-2 deconvolutions, consuming the saved
    stream activations via `pop()` as skip connections. Ends with a 1x1
    `nin` projection to logistic-mixture parameters.

    Args:
        x: input image tensor; presumably NHWC -- TODO confirm.
        context: spatial conditioning tensor fed to the first-stage blocks.
        nr_logistic_mix: number of mixture components; output has
            10 * nr_logistic_mix channels.
        nr_resnet: number of gated residual blocks per stage per stream.
        nr_filters: channel width of all internal convolutions.
        dropout_p: dropout probability inside gated_resnet.
        nonlinearity: activation passed to gated_resnet.
        bn: must be False; batch norm breaks autoregressive conditioning.
        kernel_initializer: initializer forwarded to the conv layers.
        kernel_regularizer: regularizer forwarded to the conv layers.
        is_training: training-phase flag forwarded to the layers.
        counters: name-uniquing dict for `get_name`. A fresh dict is created
            per call when omitted (the old `counters={}` default was a shared
            mutable object across calls).

    Returns:
        Tensor with `10 * nr_logistic_mix` channels.
    """
    if counters is None:
        counters = {}
    name = get_name("forward_pixel_cnn_32", counters)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-regressive model should not use batch normalization"
    with tf.variable_scope(name):
        with arg_scope([gated_resnet], gh=None, sh=None,
                       nonlinearity=nonlinearity, dropout_p=dropout_p):
            with arg_scope([gated_resnet, down_shifted_conv2d,
                            down_right_shifted_conv2d, down_shifted_deconv2d,
                            down_right_shifted_deconv2d],
                           bn=bn, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer,
                           is_training=is_training, counters=counters):
                # ////////// up pass ////////
                xs = int_shape(x)
                # Channel of ones distinguishes real pixels from padding.
                x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
                # Stream for pixels above the current one.
                u_list = [down_shift(down_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
                # Stream for pixels above and to the left.
                ul_list = [down_shift(down_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                    right_shift(down_right_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[2, 1]))]
                # Stage 1 (full resolution): conditioned on `context`.
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(
                        u_list[-1], sh=context, conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(
                        ul_list[-1], u_list[-1], sh=context,
                        conv=down_right_shifted_conv2d))
                # Downsample both streams by 2.
                u_list.append(down_shifted_conv2d(
                    u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(
                    ul_list[-1], num_filters=nr_filters, strides=[2, 2]))
                # Stage 2 (half resolution): unconditioned.
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(
                        u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(
                        ul_list[-1], u_list[-1],
                        conv=down_right_shifted_conv2d))
                # Downsample both streams by 2 again.
                u_list.append(down_shifted_conv2d(
                    u_list[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_list.append(down_right_shifted_conv2d(
                    ul_list[-1], num_filters=nr_filters, strides=[2, 2]))
                # Stage 3 (quarter resolution): unconditioned.
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(
                        u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(
                        ul_list[-1], u_list[-1],
                        conv=down_right_shifted_conv2d))
                # /////// down pass ////////
                u = u_list.pop()
                ul = ul_list.pop()
                # Mirror stage 3, consuming saved activations as skips.
                for rep in range(nr_resnet):
                    u = gated_resnet(
                        u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(
                        ul, tf.concat([u, ul_list.pop()], 3),
                        conv=down_right_shifted_conv2d)
                # Upsample both streams by 2.
                u = down_shifted_deconv2d(
                    u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(
                    ul, num_filters=nr_filters, strides=[2, 2])
                # Mirror stage 2 (+1 to also consume the downsampling skip).
                for rep in range(nr_resnet + 1):
                    u = gated_resnet(
                        u, u_list.pop(), conv=down_shifted_conv2d)
                    ul = gated_resnet(
                        ul, tf.concat([u, ul_list.pop()], 3),
                        conv=down_right_shifted_conv2d)
                # Upsample back to full resolution.
                u = down_shifted_deconv2d(
                    u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(
                    ul, num_filters=nr_filters, strides=[2, 2])
                # Mirror stage 1 (sh=None: context not re-injected here).
                for rep in range(nr_resnet + 1):
                    u = gated_resnet(
                        u, u_list.pop(), sh=None, conv=down_shifted_conv2d)
                    ul = gated_resnet(
                        ul, tf.concat([u, ul_list.pop()], 3), sh=None,
                        conv=down_right_shifted_conv2d)
                x_out = nin(tf.nn.elu(ul), 10 * nr_logistic_mix)
                # Every saved activation must have been consumed exactly once.
                assert len(u_list) == 0
                assert len(ul_list) == 0
                return x_out
def _model(self, x, nr_resnet, nr_filters, nonlinearity, dropout_p, bn,
           kernel_initializer, kernel_regularizer, is_training):
    """Build a U-net-shaped PixelCNN over `x` with a 1-channel output.

    Up pass: three stages of `nr_resnet` gated residual blocks on two
    causal streams (down-shifted and down-right-shifted convolutions),
    with stride-2 downsampling between stages, saving every activation.
    Down pass: mirrors the up pass with stride-2 deconvolutions, consuming
    the saved activations via `pop()` as skip connections. Ends with a
    1x1 `nin` projection to a single channel.

    Args:
        x: input tensor; presumably NHWC -- TODO confirm.
        nr_resnet: number of gated residual blocks per stage per stream.
        nr_filters: channel width of all internal convolutions.
        nonlinearity: activation forwarded to gated_resnet.
        dropout_p: dropout probability forwarded to gated_resnet.
        bn: batch-norm flag forwarded to the conv layers.
        kernel_initializer: initializer forwarded to the conv layers.
        kernel_regularizer: regularizer forwarded to the conv layers.
        is_training: training-phase flag forwarded to the layers.

    Returns:
        Tensor with 1 output channel.
    """
    with arg_scope([gated_resnet], nonlinearity=nonlinearity,
                   dropout_p=dropout_p, counters=self.counters):
        with arg_scope([gated_resnet, down_shifted_conv2d,
                        down_right_shifted_conv2d, down_shifted_deconv2d,
                        down_right_shifted_deconv2d],
                       bn=bn, kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       is_training=is_training, counters=self.counters):
            # ////////// up pass through pixelCNN ////////
            xs = int_shape(x)
            #ap = tf.Variable(np.zeros((xs[1], xs[2], 1), dtype=np.float32), trainable=True)
            #aps = tf.stack([ap for _ in range(xs[0])], axis=0)
            # Channel of ones distinguishes real pixels from padding.
            x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
            u_list = [down_shift(down_shifted_conv2d(
                x_pad, num_filters=nr_filters,
                filter_size=[2, 3]))]  # stream for pixels above
            ul_list = [down_shift(down_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                right_shift(down_right_shifted_conv2d(
                    x_pad, num_filters=nr_filters,
                    filter_size=[2, 1]))]  # stream for up and to the left
            # Stage 1: full resolution.
            for rep in range(nr_resnet):
                u_list.append(gated_resnet(
                    u_list[-1], conv=down_shifted_conv2d))
                # ul stream also sees the u stream's activations.
                ul_list.append(gated_resnet(
                    ul_list[-1], u_list[-1],
                    conv=down_right_shifted_conv2d))
            # Downsample both streams by 2.
            u_list.append(down_shifted_conv2d(
                u_list[-1], num_filters=nr_filters, strides=[2, 2]))
            ul_list.append(down_right_shifted_conv2d(
                ul_list[-1], num_filters=nr_filters, strides=[2, 2]))
            # Stage 2: half resolution.
            for rep in range(nr_resnet):
                u_list.append(gated_resnet(
                    u_list[-1], conv=down_shifted_conv2d))
                ul_list.append(gated_resnet(
                    ul_list[-1], u_list[-1],
                    conv=down_right_shifted_conv2d))
            # Downsample both streams by 2 again.
            u_list.append(down_shifted_conv2d(
                u_list[-1], num_filters=nr_filters, strides=[2, 2]))
            ul_list.append(down_right_shifted_conv2d(
                ul_list[-1], num_filters=nr_filters, strides=[2, 2]))
            # Stage 3: quarter resolution.
            for rep in range(nr_resnet):
                u_list.append(gated_resnet(
                    u_list[-1], conv=down_shifted_conv2d))
                ul_list.append(gated_resnet(
                    ul_list[-1], u_list[-1],
                    conv=down_right_shifted_conv2d))
            # /////// down pass ////////
            u = u_list.pop()
            ul = ul_list.pop()
            # Mirror stage 3, consuming saved activations as skips.
            for rep in range(nr_resnet):
                u = gated_resnet(
                    u, u_list.pop(), conv=down_shifted_conv2d)
                ul = gated_resnet(ul, tf.concat(
                    [u, ul_list.pop()], 3), conv=down_right_shifted_conv2d)
            # Upsample both streams by 2.
            u = down_shifted_deconv2d(
                u, num_filters=nr_filters, strides=[2, 2])
            ul = down_right_shifted_deconv2d(
                ul, num_filters=nr_filters, strides=[2, 2])
            # Mirror stage 2 (+1 to also consume the downsampling skip).
            for rep in range(nr_resnet + 1):
                u = gated_resnet(
                    u, u_list.pop(), conv=down_shifted_conv2d)
                ul = gated_resnet(ul, tf.concat(
                    [u, ul_list.pop()], 3), conv=down_right_shifted_conv2d)
            # Upsample back to full resolution.
            u = down_shifted_deconv2d(
                u, num_filters=nr_filters, strides=[2, 2])
            ul = down_right_shifted_deconv2d(
                ul, num_filters=nr_filters, strides=[2, 2])
            # Mirror stage 1.
            for rep in range(nr_resnet + 1):
                u = gated_resnet(
                    u, u_list.pop(), conv=down_shifted_conv2d)
                ul = gated_resnet(ul, tf.concat(
                    [u, ul_list.pop()], 3), conv=down_right_shifted_conv2d)
            x_out = nin(tf.nn.elu(ul), 1)
            # Every saved activation must have been consumed exactly once.
            assert len(u_list) == 0
            assert len(ul_list) == 0
            return x_out