def model_spec(x, init=False, ema=None, dropout_p=args.dropout_p):
    counters = {}
    with scopes.arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet,
                           nn.aux_gated_resnet, nn.dense],
                          counters=counters, init=init, ema=ema, dropout_p=dropout_p):

        # ////////// up pass through pixelCNN ////////
        xs = nn.int_shape(x)
        # add channel of ones to distinguish image from padding later on
        x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

        # stream for pixels above
        u_list = [nn.down_shift(nn.down_shifted_conv2d(
            x_pad, num_filters=args.nr_filters, filter_size=[2, 3]))]
        # stream for up and to the left
        ul_list = [nn.down_shift(nn.down_shifted_conv2d(
                       x_pad, num_filters=args.nr_filters, filter_size=[2, 3])) +
                   nn.right_shift(nn.down_right_shifted_conv2d(
                       x_pad, num_filters=args.nr_filters, filter_size=[2, 1]))]

        for rep in range(args.nr_resnet):
            u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(nn.aux_gated_resnet(ul_list[-1], u_list[-1],
                                               conv=nn.down_right_shifted_conv2d))

        # downsample (stride-2 shifted convolutions)
        u_list.append(nn.down_shifted_conv2d(
            u_list[-1], num_filters=args.nr_filters, stride=[2, 2]))
        ul_list.append(nn.down_right_shifted_conv2d(
            ul_list[-1], num_filters=args.nr_filters, stride=[2, 2]))

        for rep in range(args.nr_resnet):
            u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(nn.aux_gated_resnet(ul_list[-1], u_list[-1],
                                               conv=nn.down_right_shifted_conv2d))

        # downsample (stride-2 shifted convolutions)
        u_list.append(nn.down_shifted_conv2d(
            u_list[-1], num_filters=args.nr_filters, stride=[2, 2]))
        ul_list.append(nn.down_right_shifted_conv2d(
            ul_list[-1], num_filters=args.nr_filters, stride=[2, 2]))

        for rep in range(args.nr_resnet):
            u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
            ul_list.append(nn.aux_gated_resnet(ul_list[-1], u_list[-1],
                                               conv=nn.down_right_shifted_conv2d))

        # /////// down pass ////////
        u = u_list.pop()
        ul = ul_list.pop()

        for rep in range(args.nr_resnet):
            u = nn.aux_gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul, tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        # upsample (stride-2 shifted deconvolutions)
        u = nn.down_shifted_deconv2d(u, num_filters=args.nr_filters, stride=[2, 2])
        ul = nn.down_right_shifted_deconv2d(ul, num_filters=args.nr_filters, stride=[2, 2])

        for rep in range(args.nr_resnet + 1):
            u = nn.aux_gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul, tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        # upsample (stride-2 shifted deconvolutions)
        u = nn.down_shifted_deconv2d(u, num_filters=args.nr_filters, stride=[2, 2])
        ul = nn.down_right_shifted_deconv2d(ul, num_filters=args.nr_filters, stride=[2, 2])

        for rep in range(args.nr_resnet + 1):
            u = nn.aux_gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
            ul = nn.aux_gated_resnet(ul, tf.concat([u, ul_list.pop()], 3),
                                     conv=nn.down_right_shifted_conv2d)

        x_out = nn.nin(tf.nn.elu(ul), 10 * args.nr_logistic_mix)

        # every skip connection produced by the up pass should have been consumed
        assert len(u_list) == 0
        assert len(ul_list) == 0

        return x_out
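
# Hedged usage sketch (not part of the original file): one plausible way to drive
# model_spec above in a TF1 graph, following the common PixelCNN++ pattern of a
# data-dependent initialization pass followed by variable reuse via tf.make_template.
# The placeholder shapes, `args.batch_size`, and `nn.discretized_mix_logistic_loss`
# are assumptions about this repo's args/nn modules, not confirmed by the code above.
def _model_spec_usage_sketch():
    model = tf.make_template('model', model_spec)
    x_init = tf.placeholder(tf.float32, shape=(args.batch_size, 32, 32, 3))
    model(x_init, init=True)  # data-dependent init of the weight-normalized layers
    x_in = tf.placeholder(tf.float32, shape=(args.batch_size, 32, 32, 3))
    out = model(x_in)  # (N, 32, 32, 10 * nr_logistic_mix) mixture parameters per pixel
    # assumed loss helper: discretized logistic mixture likelihood as in PixelCNN++
    return nn.discretized_mix_logistic_loss(x_in, out)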
def EmbeddingImagenet(inp):
    """Returns the Imagenet-specific grayscale embedding for the given input."""
    with tf.name_scope("embedding"):
        channels_cond = 64
        leak = nn.conv(inp, "conv_leak", filter_size=(3, 3), stride=1,
                       out_channels=channels_cond)

        with tf.name_scope("down_pass"):
            leak = nn.gated_resnet(leak, "down_leak_%d" % 1, a=None, conv=nn.conv)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 2, a=None, conv=nn.conv)

            channels_cond *= 2
            leak = nn.conv(leak, "downscale_leak_1", filter_size=(3, 3), stride=2,
                           out_channels=channels_cond)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 3, a=None, conv=nn.conv)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 4, a=None, conv=nn.conv)

            channels_cond *= 2
            leak = nn.conv(leak, "downscale_leak_2", filter_size=(3, 3), stride=2,
                           out_channels=channels_cond)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 5, a=None, conv=nn.conv)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 6, a=None, conv=nn.conv)

            channels_cond *= 2
            leak = nn.conv(leak, "downscale_leak_3", filter_size=(3, 3), stride=1,
                           out_channels=channels_cond)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 7, a=None, conv=nn.conv, dilation=2)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 8, a=None, conv=nn.conv, dilation=2)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 9, a=None, conv=nn.conv, dilation=2)

            leak = nn.conv(leak, "downscale_leak_4", filter_size=(3, 3), stride=1,
                           out_channels=channels_cond)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 10, a=None, conv=nn.conv, dilation=4)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 11, a=None, conv=nn.conv, dilation=4)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 12, a=None, conv=nn.conv, dilation=4)

            # Minor bug: wrong number of channels (TODO: retrain the model and fix the code)
            embedding = nn.conv(leak, "downscale_leak_5", filter_size=(3, 3), stride=1,
                                out_channels=160)

        return embedding
def PIColorization(x, x_gray, channels, l, num_outputs, dataset, return_embedding=False):
    """Define the auto-regressive network.

    Args:
        x: input image tensor.
        x_gray: grayscale conditioning input.
        channels: network width.
        l (int): number of gated residual layers in the PixelCNN++ part.
        num_outputs (int): number of output coefficients
            (i.e. number of logistic mixtures * coefficients per mixture).
        dataset (str): dataset, one of 'cifar' or 'imagenet'.
        return_embedding (bool, optional): if True, also return the embedding.
            Defaults to False.
    """
    # PIC
    with tf.name_scope("pic"):
        with tf.name_scope("pad"):
            # add a channel of ones to distinguish image from padding later on
            x_pad = tf.concat([x, tf.ones(nn.int_shape(x)[:-1] + [1])], 3, name="x_pad")
            x_gray = tf.concat([x_gray, tf.ones(nn.int_shape(x_gray)[:-1] + [1])], 3,
                               name="gray_pad")

        # Embedding of the grayscale input
        assert dataset in ['cifar', 'imagenet']
        if dataset == 'cifar':
            embedding = EmbeddingCIFAR(x_gray)
        elif dataset == 'imagenet':
            embedding = EmbeddingImagenet(x_gray)

        # PixelCNN++
        with tf.name_scope("pcnn"):
            # stream for pixels above
            u = nn.down_shift(nn.down_shifted_conv2d(
                x_pad, "conv_down", filter_size=(2, 3), out_channels=channels))
            # stream for pixels above and to the left
            ul = nn.down_shift(nn.down_shifted_conv2d(
                     x_pad, "conv_down_2", filter_size=(1, 3), out_channels=channels)) + \
                 nn.right_shift(nn.down_right_shifted_conv2d(
                     x_pad, "conv_down_right", filter_size=(2, 1), out_channels=channels))

            for rep in range(l):
                u = nn.gated_resnet(u, "shortrange_down_%d" % rep, a=embedding,
                                    conv=nn.down_shifted_conv2d)
                ul = nn.gated_resnet(ul, "shortrange_down_right_%d" % rep,
                                     a=tf.concat([u, embedding], 3),
                                     conv=nn.down_right_shifted_conv2d)

            x_out = nn.conv(tf.nn.elu(ul), "conv_last", (1, 1), num_outputs)

    if return_embedding:
        return x_out, embedding
    else:
        return x_out
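
# Hedged usage sketch (not part of the original file): building the colorization
# network for CIFAR-sized inputs. The placeholder shapes, the channel width, the
# number of PixelCNN layers, and the choice of 10 logistic mixtures are illustrative
# assumptions only; they are not dictated by the function above.
def _pic_colorization_usage_sketch():
    x = tf.placeholder(tf.float32, shape=(16, 32, 32, 3))      # color image being modeled
    x_gray = tf.placeholder(tf.float32, shape=(16, 32, 32, 1))  # grayscale conditioning input
    nr_logistic_mix = 10
    x_out, embedding = PIColorization(x, x_gray, channels=160, l=5,
                                      num_outputs=10 * nr_logistic_mix,
                                      dataset='cifar', return_embedding=True)
    return x_out, embedding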
def EmbeddingCIFAR(inp):
    """Returns the CIFAR-specific grayscale embedding for the given input."""
    with tf.name_scope("embedding"):
        channels_cond = 32
        leak = nn.conv(inp, "conv_leak", filter_size=(3, 3), stride=1,
                       out_channels=channels_cond)

        with tf.name_scope("down_pass"):
            leak = nn.gated_resnet(leak, "down_leak_%d" % 1, a=None, conv=nn.conv)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 2, a=None, conv=nn.conv)

            channels_cond *= 2
            leak = nn.conv(leak, "downscale_leak_1", filter_size=(3, 3), stride=2,
                           out_channels=channels_cond)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 3, a=None, conv=nn.conv)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 4, a=None, conv=nn.conv)

            channels_cond *= 2
            leak = nn.conv(leak, "downscale_leak_2", filter_size=(3, 3), stride=1,
                           out_channels=channels_cond)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 5, a=None, conv=nn.conv)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 6, a=None, conv=nn.conv)

            channels_cond *= 2
            leak = nn.conv(leak, "downscale_leak_3", filter_size=(3, 3), stride=1,
                           out_channels=channels_cond)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 7, a=None, conv=nn.conv, dilation=2)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 8, a=None, conv=nn.conv, dilation=2)
            leak = nn.gated_resnet(leak, "down_leak_%d" % 9, a=None, conv=nn.conv, dilation=2)

            embedding = nn.conv(leak, "downscale_leak_4", filter_size=(3, 3), stride=1,
                                out_channels=channels_cond)

        return embedding
def model_spec(x, h=None, init=False, ema=None, dropout_p=0.5, nr_resnet=5,
               nr_filters=160, nr_logistic_mix=10, resnet_nonlinearity='concat_elu',
               attention=False, nr_attn_block=1):
    """
    We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and produce a
    Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiber of the
    x_out tensor describes the predictive distribution for the RGB at that position.
    'h' is an optional N x K matrix of values to condition our generative model on.
    """

    counters = {}
    with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense, nn.nin],
                   counters=counters, init=init, ema=ema, dropout_p=dropout_p):

        # parse resnet nonlinearity argument
        if resnet_nonlinearity == 'concat_elu':
            resnet_nonlinearity = nn.concat_elu
        elif resnet_nonlinearity == 'elu':
            resnet_nonlinearity = tf.nn.elu
        elif resnet_nonlinearity == 'relu':
            resnet_nonlinearity = tf.nn.relu
        else:
            raise ValueError('resnet nonlinearity ' + resnet_nonlinearity + ' is not supported')

        with arg_scope([nn.gated_resnet], nonlinearity=resnet_nonlinearity, h=h):

            # ////////// up pass through pixelCNN ////////
            xs = nn.int_shape(x)

            # normalized row / column coordinates of each pixel, broadcast to the shape of x
            background = tf.concat(
                [
                    ((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) / xs[1])[None, :, None, None] + 0. * x,
                    ((tf.range(xs[2], dtype=tf.float32) - xs[2] / 2) / xs[2])[None, None, :, None] + 0. * x,
                ],
                axis=3)

            # add channel of ones to distinguish image from padding later on
            x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

            # stream for pixels above
            u_list = [nn.down_shift(nn.down_shifted_conv2d(
                x_pad, num_filters=nr_filters, filter_size=[2, 3]))]
            # stream for up and to the left
            ul_list = [nn.down_shift(nn.down_shifted_conv2d(
                           x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                       nn.right_shift(nn.down_right_shifted_conv2d(
                           x_pad, num_filters=nr_filters, filter_size=[2, 1]))]

            for attn_rep in range(nr_attn_block):
                for rep in range(nr_resnet):
                    u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
                    ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1],
                                                   conv=nn.down_right_shifted_conv2d))

                if attention:
                    ul = ul_list[-1]
                    raw_content = tf.concat([x, ul, background], axis=3)
                    key, mixin = tf.split(
                        nn.nin(nn.gated_resnet(raw_content, conv=nn.nin), nr_filters * 2),
                        2, axis=3)
                    query = nn.nin(
                        nn.gated_resnet(tf.concat([ul, background], axis=3), conv=nn.nin),
                        nr_filters)
                    mixed = nn.causal_attention(key, mixin, query)
                    ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))

            x_out = nn.nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)

            return x_out
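
# Hedged usage sketch (not part of the original file): instantiating the attention
# variant of model_spec above with optional class conditioning. The shapes, the number
# of classes, and the use of tf.make_template mirror the usual PixelCNN++ training
# setup and are assumptions about how this module is driven.
def _attention_model_usage_sketch():
    model = tf.make_template('model', model_spec)
    x_in = tf.placeholder(tf.float32, shape=(12, 32, 32, 3))
    h_in = tf.placeholder(tf.float32, shape=(12, 10))  # e.g. one-hot class labels
    # init pass fills in the data-dependent weight-normalization parameters
    model(x_in, h=h_in, init=True, attention=True, nr_attn_block=1)
    out = model(x_in, h=h_in, attention=True, nr_attn_block=1)
    return out  # (12, 32, 32, 10 * nr_logistic_mix) mixture parameters per pixel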