Example #1
    def decoder(self, inputs, skips, nout):
        """VGG based image decoder.

    Args:
      inputs: image tensor with size BSxX
      skips: skip connections from encoder
      nout: number of output channels
    Returns:
      net: decoded image with size BSx64x64xNout
      skips: skip connection after each layer
    """
        vgg_layer = common_video.vgg_layer
        net = inputs
        # d1
        net = tfl.conv2d_transpose(net,
                                   512,
                                   kernel_size=4,
                                   padding="VALID",
                                   name="d1_deconv",
                                   activation=None)
        net = tfl.batch_normalization(net,
                                      training=self.is_training,
                                      name="d1_bn")
        net = tf.nn.leaky_relu(net)
        net = common_layers.upscale(net, 2)
        # d2
        net = tf.concat([net, skips[-1]], axis=3)
        net = tfcl.repeat(net, 2, vgg_layer, 512, scope="d2a")
        net = tfcl.repeat(net, 1, vgg_layer, 256, scope="d2b")
        net = common_layers.upscale(net, 2)
        # d3
        net = tf.concat([net, skips[-2]], axis=3)
        net = tfcl.repeat(net, 2, vgg_layer, 256, scope="d3a")
        net = tfcl.repeat(net, 1, vgg_layer, 128, scope="d3b")
        net = common_layers.upscale(net, 2)
        # d4
        net = tf.concat([net, skips[-3]], axis=3)
        net = tfcl.repeat(net, 1, vgg_layer, 128, scope="d4a")
        net = tfcl.repeat(net, 1, vgg_layer, 64, scope="d4b")
        net = common_layers.upscale(net, 2)
        # d5
        net = tf.concat([net, skips[-4]], axis=3)
        net = tfcl.repeat(net, 1, vgg_layer, 64, scope="d5")

        # if there are still skip connections left, we have more upscaling to do
        for i, s in enumerate(skips[-5::-1]):
            net = common_layers.upscale(net, 2)
            net = tf.concat([net, s], axis=3)
            net = tfcl.repeat(net, 1, vgg_layer, 64, scope="upscale%d" % i)

        net = tfl.conv2d_transpose(net,
                                   nout,
                                   kernel_size=3,
                                   padding="SAME",
                                   name="d6_deconv",
                                   activation=tf.sigmoid)
        return net
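The method above relies on module-level aliases that the snippet does not show. A minimal sketch of what they appear to stand for, with the exact import paths treated as assumptions (they follow the TF 1.x / tensor2tensor naming but may differ in the source repository):

# Hypothetical module-level setup assumed by the decoder above (TF 1.x era);
# the import paths are guesses and may not match the original file.
import tensorflow as tf
import tensorflow.contrib.layers as tfcl            # provides repeat()
from tensor2tensor.layers import common_layers      # assumed source of upscale()
from tensor2tensor.layers import common_video       # assumed source of vgg_layer()

tfl = tf.layers  # conv2d_transpose and batch_normalization come from tf.layers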
Example #2
    def generator(self, z, reuse=False):
        with tf.variable_scope("generator", reuse=reuse):
            ch = self.args.g_filters
            x = spectral_deconv2d(z, filters=ch, kernel_size=4, stride=1, is_training=self.is_training, padding='VALID',
                                  use_bias=False, scope='deconv2d')
            x = batch_norm(x, self.is_training, scope='batch_norm')
            x = tf.nn.relu(x)

            for i in range(self.layer_num // 2):
                if self.args.up_sample:
                    x = upscale(x, f=2)
                    x = spectral_conv2d(x, filters=ch // 2, kernel_size=3, stride=1, is_training=self.is_training,
                                        padding='SAME', scope='up_conv2d_' + str(i))
                else:
                    x = spectral_deconv2d(x, filters=ch // 2, kernel_size=4, stride=2, is_training=self.is_training,
                                          use_bias=False, scope='deconv2d_' + str(i))
                x = batch_norm(x, self.is_training, scope='batch_norm_' + str(i))
                x = tf.nn.relu(x)

                ch = ch // 2

            # Self Attention
            x = attention(x, ch, is_training=self.is_training, scope="attention", reuse=reuse)

            for i in range(self.layer_num // 2, self.layer_num):
                if self.args.up_sample:
                    x = upscale(x, f=2)
                    x = spectral_conv2d(x, filters=ch // 2, kernel_size=3, stride=1, is_training=self.is_training,
                                        padding='SAME', scope='up_conv2d_' + str(i))
                else:
                    x = spectral_deconv2d(x, filters=ch // 2, kernel_size=4, stride=2, is_training=self.is_training,
                                          use_bias=False, scope='deconv2d_' + str(i))
                x = batch_norm(x, self.is_training, scope='batch_norm_' + str(i))
                x = tf.nn.relu(x)

                ch = ch // 2

            if self.args.up_sample:
                x = upscale(x, f=2)
                x = spectral_conv2d(x, filters=self.args.img_size[2], kernel_size=3, stride=1, is_training=self.is_training,
                                    padding='SAME', scope='G_conv_logit')
            else:
                x = spectral_deconv2d(x, filters=self.args.img_size[2], kernel_size=4, stride=2, is_training=self.is_training,
                                      use_bias=False, scope='G_deconv_logit')
            x = tf.nn.tanh(x)

            return x
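The generator takes a latent tensor z; the spectral_conv2d, spectral_deconv2d, batch_norm, upscale and attention helpers come from the surrounding repository. A hypothetical call, purely for illustration (the shapes, model object and hyperparameters below are made up):

# Hypothetical usage; `model`, batch_size and z_dim are illustrative assumptions.
batch_size, z_dim = 16, 128
z = tf.random_normal([batch_size, 1, 1, z_dim])   # 1x1 spatial latent, grown by the VALID deconv
fake_images = model.generator(z)                  # in [-1, 1] because of the final tanh
fake_again = model.generator(z, reuse=True)       # a second call shares the generator variables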
Example #3
  def decoder(self, inputs, skips, nout):
    """VGG based image decoder.

    Args:
      inputs: image tensor with size BSxX
      skips: skip connections from encoder
      nout: number of output channels
    Returns:
      net: decoded image with size BSx64x64xNout
    """
    vgg_layer = self.vgg_layer
    net = inputs
    # d1
    net = slim.conv2d_transpose(net, 512, kernel_size=4, padding="VALID",
                                scope="d1_deconv", activation_fn=None)
    net = slim.batch_norm(net, scope="d1_bn")
    net = tf.nn.leaky_relu(net)
    net = common_layers.upscale(net, 2)
    # d2
    net = tf.concat([net, skips[3]], axis=3)
    net = slim.repeat(net, 2, vgg_layer, 512, scope="d2a")
    net = slim.repeat(net, 1, vgg_layer, 256, scope="d2b")
    net = common_layers.upscale(net, 2)
    # d3
    net = tf.concat([net, skips[2]], axis=3)
    net = slim.repeat(net, 2, vgg_layer, 256, scope="d3a")
    net = slim.repeat(net, 1, vgg_layer, 128, scope="d3b")
    net = common_layers.upscale(net, 2)
    # d4
    net = tf.concat([net, skips[1]], axis=3)
    net = slim.repeat(net, 1, vgg_layer, 128, scope="d4a")
    net = slim.repeat(net, 1, vgg_layer, 64, scope="d4b")
    net = common_layers.upscale(net, 2)
    # d5
    net = tf.concat([net, skips[0]], axis=3)
    net = slim.repeat(net, 1, vgg_layer, 64, scope="d5")
    net = slim.conv2d_transpose(net, nout, kernel_size=3, padding="SAME",
                                scope="d6_deconv", activation_fn=tf.sigmoid)
    return net
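Here slim is tf.contrib.slim, and slim.repeat applies the same layer function the given number of times while numbering the inner scopes. The two-layer d2a block above is roughly equivalent to the expansion below (a sketch, assuming vgg_layer accepts (inputs, filters, scope)):

# Rough unrolling of slim.repeat(net, 2, vgg_layer, 512, scope="d2a");
# slim.repeat suffixes the scope for each repetition.
net = vgg_layer(net, 512, scope="d2a/d2a_1")
net = vgg_layer(net, 512, scope="d2a/d2a_2")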
Example #4
    def decoder(self, inputs, nout, skips=None, has_batchnorm=True):
        """VGG based image decoder.

    Args:
      inputs: image tensor with size BSxX
      nout: number of output channels
      skips: optional skip connections from encoder
      has_batchnorm: variable to use or not use batch normalization
    Returns:
      net: decoded image with size BSx64x64xNout
      skips: skip connection after each layer
    """
        vgg_layer = common_video.vgg_layer
        net = inputs
        # d1
        net = tfl.conv2d_transpose(net,
                                   512,
                                   kernel_size=4,
                                   padding="VALID",
                                   name="d1_deconv",
                                   activation=tf.nn.relu)
        if has_batchnorm:
            net = tfl.batch_normalization(net,
                                          training=self.is_training,
                                          name="d1_bn")
        net = tf.nn.relu(net)
        net = common_layers.upscale(net, 2)
        # d2
        if skips is not None:
            net = tf.concat([net, skips[-1]], axis=3)
        net = tfcl.repeat(net,
                          2,
                          vgg_layer,
                          512,
                          scope="d2a",
                          is_training=self.is_training,
                          activation=tf.nn.relu,
                          has_batchnorm=has_batchnorm)
        net = tfcl.repeat(net,
                          1,
                          vgg_layer,
                          256,
                          scope="d2b",
                          is_training=self.is_training,
                          activation=tf.nn.relu,
                          has_batchnorm=has_batchnorm)
        net = common_layers.upscale(net, 2)
        # d3
        if skips is not None:
            net = tf.concat([net, skips[-2]], axis=3)
        net = tfcl.repeat(net,
                          2,
                          vgg_layer,
                          256,
                          scope="d3a",
                          is_training=self.is_training,
                          activation=tf.nn.relu,
                          has_batchnorm=has_batchnorm)
        net = tfcl.repeat(net,
                          1,
                          vgg_layer,
                          128,
                          scope="d3b",
                          is_training=self.is_training,
                          activation=tf.nn.relu,
                          has_batchnorm=has_batchnorm)
        net = common_layers.upscale(net, 2)
        # d4
        if skips is not None:
            net = tf.concat([net, skips[-3]], axis=3)
        net = tfcl.repeat(net,
                          1,
                          vgg_layer,
                          128,
                          scope="d4a",
                          is_training=self.is_training,
                          activation=tf.nn.relu,
                          has_batchnorm=has_batchnorm)
        net = tfcl.repeat(net,
                          1,
                          vgg_layer,
                          64,
                          scope="d4b",
                          is_training=self.is_training,
                          activation=tf.nn.relu,
                          has_batchnorm=has_batchnorm)
        net = common_layers.upscale(net, 2)
        # d5
        if skips is not None:
            net = tf.concat([net, skips[-4]], axis=3)
        net = tfcl.repeat(net,
                          1,
                          vgg_layer,
                          64,
                          scope="d5",
                          is_training=self.is_training,
                          activation=tf.nn.relu,
                          has_batchnorm=has_batchnorm)

        # if there are still skip connections left, we have more upscaling to do
        if skips is not None:
            for i, s in enumerate(skips[-5::-1]):
                net = common_layers.upscale(net, 2)
                net = tf.concat([net, s], axis=3)
                net = tfcl.repeat(net,
                                  1,
                                  vgg_layer,
                                  64,
                                  scope="upscale%d" % i,
                                  is_training=self.is_training,
                                  activation=tf.nn.relu,
                                  has_batchnorm=has_batchnorm)

        net = tfl.conv2d_transpose(net,
                                   nout,
                                   kernel_size=3,
                                   padding="SAME",
                                   name="d6_deconv",
                                   activation=None)
        return net
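The documented BSx64x64xNout output follows from the spatial sizes, assuming d1 receives a 1x1 latent and exactly four skip connections are passed: the VALID 4x4 transposed convolution yields 4x4, and each of the four upscale(net, 2) calls doubles it. A quick sanity check of that arithmetic:

# Spatial-size bookkeeping for the decoder above (1x1 latent, four skips assumed).
size = (1 - 1) * 1 + 4      # d1: 4x4 kernel, stride 1, VALID padding -> 4
for _ in range(4):          # the upscale(net, 2) calls after d1, d2, d3 and d4
    size *= 2               # 8, 16, 32, 64
print(size)                 # 64; the SAME-padded d6_deconv keeps the 64x64 resolution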
Example #5
  def decoder(self, inputs, nout, skips=None, has_batchnorm=True):
    """VGG based image decoder.

    Args:
      inputs: image tensor with size BSxX
      nout: number of output channels
      skips: optional skip connections from encoder
      has_batchnorm: whether to apply batch normalization
    Returns:
      net: decoded image with size BSx64x64xNout
    """
    vgg_layer = common_video.vgg_layer
    net = inputs
    # d1
    net = tfl.conv2d_transpose(net, 512, kernel_size=4, padding="VALID",
                               name="d1_deconv", activation=tf.nn.relu)
    if has_batchnorm:
      net = tfl.batch_normalization(
          net, training=self.is_training, name="d1_bn")
    net = tf.nn.relu(net)
    net = common_layers.upscale(net, 2)
    # d2
    if skips is not None:
      net = tf.concat([net, skips[-1]], axis=3)
    net = tfcl.repeat(net, 2, vgg_layer, 512, scope="d2a",
                      is_training=self.is_training,
                      activation=tf.nn.relu, has_batchnorm=has_batchnorm)
    net = tfcl.repeat(net, 1, vgg_layer, 256, scope="d2b",
                      is_training=self.is_training,
                      activation=tf.nn.relu, has_batchnorm=has_batchnorm)
    net = common_layers.upscale(net, 2)
    # d3
    if skips is not None:
      net = tf.concat([net, skips[-2]], axis=3)
    net = tfcl.repeat(net, 2, vgg_layer, 256, scope="d3a",
                      is_training=self.is_training,
                      activation=tf.nn.relu, has_batchnorm=has_batchnorm)
    net = tfcl.repeat(net, 1, vgg_layer, 128, scope="d3b",
                      is_training=self.is_training,
                      activation=tf.nn.relu, has_batchnorm=has_batchnorm)
    net = common_layers.upscale(net, 2)
    # d4
    if skips is not None:
      net = tf.concat([net, skips[-3]], axis=3)
    net = tfcl.repeat(net, 1, vgg_layer, 128, scope="d4a",
                      is_training=self.is_training,
                      activation=tf.nn.relu, has_batchnorm=has_batchnorm)
    net = tfcl.repeat(net, 1, vgg_layer, 64, scope="d4b",
                      is_training=self.is_training,
                      activation=tf.nn.relu, has_batchnorm=has_batchnorm)
    net = common_layers.upscale(net, 2)
    # d5
    if skips is not None:
      net = tf.concat([net, skips[-4]], axis=3)
    net = tfcl.repeat(net, 1, vgg_layer, 64, scope="d5",
                      is_training=self.is_training,
                      activation=tf.nn.relu, has_batchnorm=has_batchnorm)

    # if there are still skip connections left, we have more upscaling to do
    if skips is not None:
      for i, s in enumerate(skips[-5::-1]):
        net = common_layers.upscale(net, 2)
        net = tf.concat([net, s], axis=3)
        net = tfcl.repeat(net, 1, vgg_layer, 64, scope="upscale%d" % i,
                          is_training=self.is_training,
                          activation=tf.nn.relu, has_batchnorm=has_batchnorm)

    net = tfl.conv2d_transpose(net, nout, kernel_size=3, padding="SAME",
                               name="d6_deconv", activation=None)
    return net
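The negative indexing (skips[-1] is concatenated at the deepest decoder stage) suggests the skip list is ordered from the shallowest to the deepest encoder layer. A hypothetical wiring, only to make that ordering concrete; the encoder call and shapes are illustrative, not taken from the source:

# Hypothetical usage; `self.encoder`, `frames` and the shapes are assumptions.
latent, skips = self.encoder(frames)               # e.g. latent BSx1x1x512 and four skip tensors
recon = self.decoder(latent, nout=3, skips=skips)  # BSx64x64x3; the output here is linear (no sigmoid)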