def forward_pixel_cnn_28_binary(x,
                                context,
                                nr_logistic_mix=10,
                                nr_resnet=1,
                                nr_filters=100,
                                dropout_p=0.0,
                                nonlinearity=None,
                                bn=True,
                                kernel_initializer=None,
                                kernel_regularizer=None,
                                is_training=False,
                                counters={}):
    """Build a two-stream shifted-convolution network over ``x``.

    The input gets an extra channel of ones, is fed to a "pixels above"
    stream and an "up and to the left" stream, both streams go through
    ``nr_resnet`` gated resnet layers, and the ELU of the final up-left
    stream is passed through ``nin(..., 1)``.

    Args:
        x: input tensor; it is concatenated along axis 3 with a ones tensor
            of shape ``int_shape(x)[:-1] + [1]``, so it must be rank 4.
        context: forwarded to every ``gated_resnet`` call as ``sh``.
        nr_logistic_mix: only printed; not otherwise used in this function.
        nr_resnet: number of gated resnet layers applied to each stream.
        nr_filters: ``num_filters`` of the initial shifted convolutions.
        dropout_p: forwarded to ``gated_resnet``.
        nonlinearity: forwarded to ``gated_resnet``.
        bn: must be falsy; the function asserts ``not bn``. Note that the
            default value (``True``) therefore trips the assertion.
        kernel_initializer: forwarded to all arg-scoped layers.
        kernel_regularizer: forwarded to all arg-scoped layers.
        is_training: forwarded to all arg-scoped layers.
        counters: dict handed to ``get_name`` and to all arg-scoped layers.
            NOTE(review): the default is one dict object shared by every call
            that omits the argument; whether naming relies on that sharing is
            not visible from this function -- confirm before changing it.

    Returns:
        The result of ``nin(tf.nn.elu(<final up-left stream>), 1)``.

    Raises:
        AssertionError: if ``bn`` is truthy.
    """
    name = get_name("forward_pixel_cnn_28_binary", counters)
    print("construct", name, "...")
    print(" * nr_resnet: ", nr_resnet)
    print(" * nr_filters: ", nr_filters)
    print(" * nr_logistic_mix: ", nr_logistic_mix)
    assert not bn, "auto-reggressive model should not use batch normalization"

    # Scopes are entered left to right, exactly as if they were nested:
    # variable scope, then gated_resnet-only defaults, then the defaults
    # shared by every layer helper.
    with tf.variable_scope(name), \
            arg_scope([gated_resnet],
                      gh=None,
                      sh=context,
                      nonlinearity=nonlinearity,
                      dropout_p=dropout_p), \
            arg_scope([gated_resnet,
                       down_shifted_conv2d,
                       down_right_shifted_conv2d,
                       down_shifted_deconv2d,
                       down_right_shifted_deconv2d],
                      bn=bn,
                      kernel_initializer=kernel_initializer,
                      kernel_regularizer=kernel_regularizer,
                      is_training=is_training,
                      counters=counters):
        input_shape = int_shape(x)
        # add channel of ones to distinguish image from padding later on
        ones_channel = tf.ones(input_shape[:-1] + [1])
        x_pad = tf.concat([x, ones_channel], 3)

        # stream for pixels above
        u = down_shift(
            down_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[2, 3]))

        # stream for up and to the left: sum of a row-above part and a
        # same-row-to-the-left part
        above_part = down_shift(
            down_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[1, 3]))
        left_part = right_shift(
            down_right_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[2, 1]))
        ul = above_part + left_part

        # Only the most recent activation of each stream is ever read, so a
        # running variable per stream is sufficient. The up-left stream is
        # given the freshly updated "above" stream as its second input.
        for _ in range(nr_resnet):
            u = gated_resnet(u, conv=down_shifted_conv2d)
            ul = gated_resnet(ul, u, conv=down_right_shifted_conv2d)

        return nin(tf.nn.elu(ul), 1)
def _model(self, x, nr_resnet, nr_filters, nonlinearity, dropout_p, bn,
           kernel_initializer, kernel_regularizer, is_training):
    """Build the up/down shifted-convolution network over ``x``.

    Up pass: at each of three scales, ``nr_resnet`` gated resnet layers are
    applied to the "pixels above" stream and the "up and to the left"
    stream, with a stride-2 shifted convolution between consecutive scales.
    Every intermediate activation is pushed onto a per-stream stack.

    Down pass: the stacks are consumed in reverse as second inputs to further
    gated resnet layers, with a stride-2 shifted deconvolution between
    consecutive scales. The first scale uses ``nr_resnet`` layers and the
    other two use ``nr_resnet + 1``, which consumes every stacked activation.

    Args:
        x: input tensor; it is concatenated along axis 3 with a ones tensor
            of shape ``int_shape(x)[:-1] + [1]``, so it must be rank 4.
        nr_resnet: number of gated resnet layers per scale in the up pass.
        nr_filters: ``num_filters`` for every (de)convolution built here.
        nonlinearity: forwarded to ``gated_resnet``.
        dropout_p: forwarded to ``gated_resnet``.
        bn: forwarded to all arg-scoped layers.
        kernel_initializer: forwarded to all arg-scoped layers.
        kernel_regularizer: forwarded to all arg-scoped layers.
        is_training: forwarded to all arg-scoped layers.

    Returns:
        The result of ``nin(tf.nn.elu(<final up-left stream>), 1)``.

    Raises:
        AssertionError: if either activation stack is not empty after the
            down pass (internal consistency check).
    """
    # Scopes are entered left to right, exactly as if they were nested.
    with arg_scope([gated_resnet],
                   nonlinearity=nonlinearity,
                   dropout_p=dropout_p,
                   counters=self.counters), \
            arg_scope([gated_resnet,
                       down_shifted_conv2d,
                       down_right_shifted_conv2d,
                       down_shifted_deconv2d,
                       down_right_shifted_deconv2d],
                      bn=bn,
                      kernel_initializer=kernel_initializer,
                      kernel_regularizer=kernel_regularizer,
                      is_training=is_training,
                      counters=self.counters):
        # ////////// up pass through pixelCNN ////////
        input_shape = int_shape(x)
        x_pad = tf.concat([x, tf.ones(input_shape[:-1] + [1])], 3)

        # stream for pixels above
        u_stack = [
            down_shift(
                down_shifted_conv2d(
                    x_pad, num_filters=nr_filters, filter_size=[2, 3]))
        ]

        # stream for up and to the left
        above_part = down_shift(
            down_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[1, 3]))
        left_part = right_shift(
            down_right_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[2, 1]))
        ul_stack = [above_part + left_part]

        for scale in range(3):
            if scale > 0:
                # Move to the next, coarser scale before its resnet layers.
                u_stack.append(
                    down_shifted_conv2d(
                        u_stack[-1], num_filters=nr_filters, strides=[2, 2]))
                ul_stack.append(
                    down_right_shifted_conv2d(
                        ul_stack[-1], num_filters=nr_filters, strides=[2, 2]))
            for _ in range(nr_resnet):
                u_stack.append(
                    gated_resnet(u_stack[-1], conv=down_shifted_conv2d))
                # The up-left stream takes the just-updated "above" stream
                # as its second input.
                ul_stack.append(
                    gated_resnet(ul_stack[-1], u_stack[-1],
                                 conv=down_right_shifted_conv2d))

        # /////// down pass ////////
        u = u_stack.pop()
        ul = ul_stack.pop()

        # The two later scales need one extra layer each so that the
        # activations pushed by the stride-2 convolutions (and the initial
        # ones) are consumed as well.
        layers_per_scale = (nr_resnet, nr_resnet + 1, nr_resnet + 1)
        for scale, n_layers in enumerate(layers_per_scale):
            if scale > 0:
                u = down_shifted_deconv2d(
                    u, num_filters=nr_filters, strides=[2, 2])
                ul = down_right_shifted_deconv2d(
                    ul, num_filters=nr_filters, strides=[2, 2])
            for _ in range(n_layers):
                u = gated_resnet(u, u_stack.pop(), conv=down_shifted_conv2d)
                skip = tf.concat([u, ul_stack.pop()], 3)
                ul = gated_resnet(ul, skip, conv=down_right_shifted_conv2d)

        x_out = nin(tf.nn.elu(ul), 1)

        # Every stacked activation must have been consumed exactly once.
        assert len(u_stack) == 0
        assert len(ul_stack) == 0
        return x_out