Example #1
0
def model_spec(x, init=False, ema=None, dropout_p=args.dropout_p):
    """Build the two-stream PixelCNN graph for ``x`` and return its output.

    The network keeps two feature streams: ``u`` (built with
    ``nn.down_shifted_conv2d``) and ``ul`` (built with
    ``nn.down_right_shifted_conv2d``).  An "up pass" pushes activations of
    both streams onto stacks, downsampling twice with stride ``[2, 2]``;
    the "down pass" pops them again as auxiliary inputs while upsampling
    twice with the matching deconvolutions.

    Args:
        x: input tensor.  It is concatenated along axis 3, so it is
            presumably a 4-D NHWC batch -- TODO confirm against callers.
        init: forwarded to the scoped ``nn`` layers through ``arg_scope``.
        ema: forwarded to the scoped ``nn`` layers through ``arg_scope``.
        dropout_p: forwarded to the scoped ``nn`` layers.  The default is
            ``args.dropout_p`` as it was when this function was defined.

    Returns:
        The result of ``nn.nin`` applied to ``elu(ul)`` with
        ``10 * args.nr_logistic_mix`` as its size argument.
    """
    layer_counters = {}
    scoped_layers = [
        nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.aux_gated_resnet,
        nn.dense
    ]
    with scopes.arg_scope(scoped_layers,
                          counters=layer_counters,
                          init=init,
                          ema=ema,
                          dropout_p=dropout_p):
        width = args.nr_filters
        depth = args.nr_resnet

        # ---------------- up pass ----------------
        in_shape = nn.int_shape(x)
        # Extra channel of ones lets later layers tell image from padding.
        ones_channel = tf.ones(in_shape[:-1] + [1])
        x_pad = tf.concat(3, [x, ones_channel])

        # Stream for the pixels above.
        above = nn.down_shifted_conv2d(x_pad,
                                       num_filters=width,
                                       filter_size=[2, 3])
        u_stack = [nn.down_shift(above)]

        # Stream for the pixels above and to the left.
        above_part = nn.down_shift(
            nn.down_shifted_conv2d(x_pad,
                                   num_filters=width,
                                   filter_size=[2, 3]))
        left_part = nn.right_shift(
            nn.down_right_shifted_conv2d(x_pad,
                                         num_filters=width,
                                         filter_size=[2, 1]))
        ul_stack = [above_part + left_part]

        # Three resolutions; a stride-2 convolution precedes the 2nd and 3rd.
        for level in range(3):
            if level > 0:
                u_stack.append(
                    nn.down_shifted_conv2d(u_stack[-1],
                                           num_filters=width,
                                           stride=[2, 2]))
                ul_stack.append(
                    nn.down_right_shifted_conv2d(ul_stack[-1],
                                                 num_filters=width,
                                                 stride=[2, 2]))
            for _ in range(depth):
                u_stack.append(
                    nn.gated_resnet(u_stack[-1], conv=nn.down_shifted_conv2d))
                # The ul stream is conditioned on the freshly pushed u layer.
                ul_stack.append(
                    nn.aux_gated_resnet(ul_stack[-1],
                                        u_stack[-1],
                                        conv=nn.down_right_shifted_conv2d))

        # ---------------- down pass ----------------
        u = u_stack.pop()
        ul = ul_stack.pop()
        # Block counts per resolution; the extra block on the last two levels
        # consumes the activations pushed by the stride-2 convolutions and
        # by the initial layers, so both stacks end up empty.
        for level, n_blocks in enumerate((depth, depth + 1, depth + 1)):
            if level > 0:
                u = nn.down_shifted_deconv2d(u,
                                             num_filters=width,
                                             stride=[2, 2])
                ul = nn.down_right_shifted_deconv2d(ul,
                                                    num_filters=width,
                                                    stride=[2, 2])
            for _ in range(n_blocks):
                u = nn.aux_gated_resnet(u,
                                        u_stack.pop(),
                                        conv=nn.down_shifted_conv2d)
                skip = tf.concat(3, [u, ul_stack.pop()])
                ul = nn.aux_gated_resnet(ul,
                                         skip,
                                         conv=nn.down_right_shifted_conv2d)

        x_out = nn.nin(tf.nn.elu(ul), 10 * args.nr_logistic_mix)

        # Every activation saved on the way up must have been consumed.
        assert len(u_stack) == 0
        assert len(ul_stack) == 0

        return x_out
Example #2
0
def EmbeddingImagenet(inp):
    """Build the Imagenet-specific grayscale embedding network on ``inp``.

    The network is an initial 3x3 convolution followed by five stages.  Each
    stage is a run of ``nn.gated_resnet`` blocks (named ``down_leak_<i>``)
    followed by a 3x3 ``nn.conv`` (named ``downscale_leak_<k>``).  The output
    of the last convolution is returned.
    """
    with tf.name_scope("embedding"):
        base_width = 64
        net = nn.conv(inp,
                      "conv_leak",
                      filter_size=(3, 3),
                      stride=1,
                      out_channels=base_width)

        with tf.name_scope("down_pass"):
            # One row per stage:
            #   (resnet block ids, extra gated_resnet kwargs,
            #    trailing conv name, trailing conv stride, conv out_channels)
            # Blocks 1-6 are called without a ``dilation`` argument at all.
            stage_plan = (
                ((1, 2), {}, "downscale_leak_1", 2, base_width * 2),
                ((3, 4), {}, "downscale_leak_2", 2, base_width * 4),
                ((5, 6), {}, "downscale_leak_3", 1, base_width * 8),
                ((7, 8, 9), {"dilation": 2}, "downscale_leak_4", 1,
                 base_width * 8),
                # Minor bug: wrong number of channels
                # (TODO: retrain the model and fix the code)
                ((10, 11, 12), {"dilation": 4}, "downscale_leak_5", 1, 160),
            )

            for block_ids, extra, conv_name, conv_stride, width in stage_plan:
                for block_id in block_ids:
                    net = nn.gated_resnet(net,
                                          "down_leak_%d" % block_id,
                                          a=None,
                                          conv=nn.conv,
                                          **extra)
                net = nn.conv(net,
                              conv_name,
                              filter_size=(3, 3),
                              stride=conv_stride,
                              out_channels=width)

    return net
Example #3
0
def PIColorization(x,
                   x_gray,
                   channels,
                   l,
                   num_outputs,
                   dataset,
                   return_embedding=False):
    """Define the auto-regressive network.

    Args:
      x: input
      x_gray: grayscale input; a channel of ones is appended and the result
        is fed to the dataset-specific embedding network
      channels: network width (``out_channels`` of the initial PixelCNN
        convolutions)
      l (int): number of gated residual layers applied to each of the two
        PixelCNN streams
      num_outputs (int): number of coeffs (ie logistic mixtures * n_coeffs per mixture)
      dataset (str): dataset, either ``'cifar'`` or ``'imagenet'``
      return_embedding (bool, optional): if True, also return the embedding. Defaults to False

    Returns:
      ``x_out`` (the output of the final 1x1 convolution), or the tuple
      ``(x_out, embedding)`` when ``return_embedding`` is True.

    Raises:
      ValueError: if ``dataset`` is not one of the supported names.
    """
    # Validate before building any ops.  An ``assert`` would be stripped under
    # ``python -O`` and leave ``embedding`` unbound further down.
    if dataset == 'cifar':
        build_embedding = EmbeddingCIFAR
    elif dataset == 'imagenet':
        build_embedding = EmbeddingImagenet
    else:
        raise ValueError(
            "dataset must be 'cifar' or 'imagenet', got %r" % (dataset,))

    # PIC
    with tf.name_scope("pic"):

        # Append a channel of ones to both inputs.
        with tf.name_scope("pad"):
            x_pad = tf.concat([x, tf.ones(nn.int_shape(x)[:-1] + [1])],
                              3,
                              name="x_pad")
            x_gray = tf.concat(
                [x_gray, tf.ones(nn.int_shape(x_gray)[:-1] + [1])],
                3,
                name="gray_pad")

        # Embedding
        embedding = build_embedding(x_gray)

        # PixelCNN++
        with tf.name_scope("pcnn"):
            u = nn.down_shift(
                nn.down_shifted_conv2d(x_pad,
                                       "conv_down",
                                       filter_size=(2, 3),
                                       out_channels=channels))
            ul = nn.down_shift(nn.down_shifted_conv2d(x_pad, "conv_down_2",  filter_size=(1, 3), out_channels=channels)) + \
                nn.right_shift(nn.down_right_shifted_conv2d(x_pad, "conv_down_right", filter_size=(2, 1), out_channels=channels))

            for rep in range(l):
                # The u stream is conditioned on the embedding only ...
                u = nn.gated_resnet(u,
                                    "shortrange_down_%d" % rep,
                                    a=embedding,
                                    conv=nn.down_shifted_conv2d)
                # ... while ul additionally sees the updated u stream.
                ul = nn.gated_resnet(ul,
                                     "shortrange_down_right_%d" % rep,
                                     a=tf.concat([u, embedding], 3),
                                     conv=nn.down_right_shifted_conv2d)

        x_out = nn.conv(tf.nn.elu(ul), "conv_last", (1, 1), num_outputs)

    if return_embedding:
        return x_out, embedding
    else:
        return x_out
Example #4
0
def EmbeddingCIFAR(inp):
    """Build the CIFAR-specific grayscale embedding network on ``inp``.

    The network is an initial 3x3 convolution followed by four stages.  Each
    stage is a run of ``nn.gated_resnet`` blocks (named ``down_leak_<i>``)
    followed by a 3x3 ``nn.conv`` (named ``downscale_leak_<k>``).  The output
    of the last convolution is returned.
    """
    with tf.name_scope("embedding"):
        base_width = 32
        net = nn.conv(inp,
                      "conv_leak",
                      filter_size=(3, 3),
                      stride=1,
                      out_channels=base_width)

        with tf.name_scope("down_pass"):
            # One row per stage:
            #   (resnet block ids, extra gated_resnet kwargs,
            #    trailing conv name, trailing conv stride, conv out_channels)
            # Blocks 1-6 are called without a ``dilation`` argument at all.
            stage_plan = (
                ((1, 2), {}, "downscale_leak_1", 2, base_width * 2),
                ((3, 4), {}, "downscale_leak_2", 1, base_width * 4),
                ((5, 6), {}, "downscale_leak_3", 1, base_width * 8),
                ((7, 8, 9), {"dilation": 2}, "downscale_leak_4", 1,
                 base_width * 8),
            )

            for block_ids, extra, conv_name, conv_stride, width in stage_plan:
                for block_id in block_ids:
                    net = nn.gated_resnet(net,
                                          "down_leak_%d" % block_id,
                                          a=None,
                                          conv=nn.conv,
                                          **extra)
                net = nn.conv(net,
                              conv_name,
                              filter_size=(3, 3),
                              stride=conv_stride,
                              out_channels=width)

    return net
Example #5
0
def model_spec(x,
               h=None,
               init=False,
               ema=None,
               dropout_p=0.5,
               nr_resnet=5,
               nr_filters=160,
               nr_logistic_mix=10,
               resnet_nonlinearity='concat_elu',
               attention=False,
               nr_attn_block=1):
    """
    We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and produce
    a Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiber
    of the x_out tensor describes the predictive distribution for the RGB at
    that position.
    'h' is an optional N x K matrix of values to condition our generative model on

    Args:
        x: input tensor (see above).
        h: optional conditioning values, forwarded to ``nn.gated_resnet``.
        init, ema, dropout_p: forwarded to the scoped ``nn`` layers through
            ``arg_scope``.
        nr_resnet: number of gated resnet layers per stream in each block.
        nr_filters: ``num_filters`` of the initial convolutions; also sizes
            the attention key/mixin/query projections.
        nr_logistic_mix: the output layer is built with
            ``10 * nr_logistic_mix`` as its size argument.
        resnet_nonlinearity: one of ``'concat_elu'``, ``'elu'``, ``'relu'``.
        attention: if True, append a causal-attention layer to the ``ul``
            stream after each block of resnet layers.
        nr_attn_block: number of (resnet layers [+ attention]) blocks.

    Returns:
        x_out, the result of ``nn.nin`` applied to ``elu`` of the last
        ``ul`` stream activation.

    Raises:
        ValueError: if ``resnet_nonlinearity`` is not a supported name.
    """

    # parse resnet nonlinearity argument
    # Done before any graph construction so that a bad name fails fast.
    # Raising a plain string is itself a TypeError in Python 3, which would
    # hide the intended message.
    if resnet_nonlinearity == 'concat_elu':
        nonlinearity_fn = nn.concat_elu
    elif resnet_nonlinearity == 'elu':
        nonlinearity_fn = tf.nn.elu
    elif resnet_nonlinearity == 'relu':
        nonlinearity_fn = tf.nn.relu
    else:
        raise ValueError('resnet nonlinearity %s is not supported' %
                         (resnet_nonlinearity,))

    counters = {}
    with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense, nn.nin],
                   counters=counters,
                   init=init,
                   ema=ema,
                   dropout_p=dropout_p):

        with arg_scope([nn.gated_resnet],
                       nonlinearity=nonlinearity_fn,
                       h=h):

            # ////////// up pass through pixelCNN ////////
            xs = nn.int_shape(x)
            # Two channels holding the normalised row and column index of
            # each position; ``0. * x`` broadcasts them against x.
            background = tf.concat([
                ((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) /
                 xs[1])[None, :, None, None] + 0. * x,
                ((tf.range(xs[2], dtype=tf.float32) - xs[2] / 2) /
                 xs[2])[None, None, :, None] + 0. * x,
            ],
                                   axis=3)
            # add channel of ones to distinguish image from padding later on
            x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)
            # stream for pixels above
            u_list = [
                nn.down_shift(
                    nn.down_shifted_conv2d(x_pad,
                                           num_filters=nr_filters,
                                           filter_size=[2, 3]))
            ]
            # stream for up and to the left
            ul_list = [
                nn.down_shift(
                    nn.down_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[1, 3])) +
                nn.right_shift(
                    nn.down_right_shifted_conv2d(
                        x_pad, num_filters=nr_filters, filter_size=[2, 1]))
            ]

            for attn_rep in range(nr_attn_block):
                for rep in range(nr_resnet):
                    u_list.append(
                        nn.gated_resnet(u_list[-1],
                                        conv=nn.down_shifted_conv2d))
                    # ul is conditioned on the u layer that was just added
                    ul_list.append(
                        nn.gated_resnet(ul_list[-1],
                                        u_list[-1],
                                        conv=nn.down_right_shifted_conv2d))

                if attention:
                    ul = ul_list[-1]
                    # keys/values are computed from the raw input as well,
                    # the query only from ul and the position channels
                    raw_content = tf.concat([x, ul, background], axis=3)
                    key, mixin = tf.split(nn.nin(
                        nn.gated_resnet(raw_content, conv=nn.nin),
                        nr_filters * 2),
                                          2,
                                          axis=3)
                    query = nn.nin(
                        nn.gated_resnet(tf.concat([ul, background], axis=3),
                                        conv=nn.nin), nr_filters)
                    mixed = nn.causal_attention(key, mixin, query)

                    ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))

            x_out = nn.nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)
            return x_out