Example #1
def first_block(x,
                target_size,
                noise_dim,
                upsampling='deconv',
                normalization='batch',
                is_training=True):
    if upsampling == 'deconv':
        _x = reshape(x, (1, 1, noise_dim))
        _x = conv2d_transpose(_x,
                              1024,
                              target_size,
                              strides=(1, 1),
                              padding='valid')
    elif upsampling == 'dense':
        _x = dense(x, target_size[0] * target_size[1] * 1024)
        _x = reshape(_x, (target_size[1], target_size[0], 1024))
    else:
        raise ValueError('Unknown upsampling: %s' % upsampling)

    if normalization == 'batch':
        _x = batch_norm(_x, is_training=is_training)
    elif normalization == 'layer':
        _x = layer_norm(_x, is_training=is_training)
    elif normalization is None:
        pass
    else:
        raise ValueError('Unknown normalization: %s' % normalization)
    _x = activation(_x, 'relu')
    return _x
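
A minimal usage sketch for first_block, assuming the reshape/dense/conv2d_transpose/batch_norm/layer_norm/activation helpers above come from the project's own ops module and that x is a batch of noise vectors (names and shapes below are illustrative, not from the original source):

# Hypothetical driver code, for illustration only.
import tensorflow as tf

noise_dim = 128
z = tf.random_normal((16, noise_dim))        # batch of 16 noise vectors
# Project the noise to a 4x4x1024 feature map via the dense branch,
# then batch-normalize and apply ReLU, ready for later upsampling blocks.
h = first_block(z,
                target_size=(4, 4),
                noise_dim=noise_dim,
                upsampling='dense',
                normalization='batch',
                is_training=True)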
Example #2
    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            for idx, (f, k, s, p) in enumerate(zip(self.list_filters, self.list_kernel_size, self.list_strides, self.list_padding)):
                if idx == 0:
                    bn = False
                else:
                    bn = True
                name = "conv2D_%s" % idx
                x = layers.conv2d_block(name, x, f, k, s, p=p, stddev=0.02,
                                        data_format=self.data_format, bias=True, bn=bn, activation_fn=layers.lrelu)

            target_shape = (self.batch_size, -1)
            x = layers.reshape(x, target_shape)

            # # Add MBD
            # x_mbd = layers.mini_batch_disc(x, num_kernels=100, dim_per_kernel=5)
            # # Concat
            # x = tf.concat([x, x_mbd], axis=1)

            x = layers.linear(x, 1)

            return x
Example #3
        def __call__(self, x, reuse=False):
            with tf.variable_scope(self.name) as scope:

                if reuse:
                    scope.reuse_variables()

                M, N = x.get_shape().as_list()[-2:]
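                # Fixed wavelet scattering transform (J=2 scales) used as a non-learned feature extractor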
                x = scattering.Scattering(M=M, N=N, J=2)(x)
                x = tf.contrib.layers.batch_norm(x,
                                                 data_format=FLAGS.data_format,
                                                 fused=True,
                                                 scope="scat_bn")
                x = layers.conv2d_block("CONV2D",
                                        x,
                                        64,
                                        1,
                                        1,
                                        p="SAME",
                                        data_format=FLAGS.data_format,
                                        bias=True,
                                        bn=False,
                                        activation_fn=tf.nn.relu)

                target_shape = (-1, 64 * 7 * 7)
                x = layers.reshape(x, target_shape)
                x = layers.linear(x, 512, name="dense1")
                x = tf.nn.relu(x)
                x = layers.linear(x, 10, name="dense2")

                return x
Example #4
    def __call__(self, x, reuse=False, output_name=None):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            # Initial dense multiplication
            x = layers.linear(x, "G_FC1", 512 * 8 * 8)

            batch_size = tf.shape(x)[0]
            if FLAGS.data_format == "NHWC":
                target_shape = (batch_size, 8, 8, 512)
            elif FLAGS.data_format == "NCHW":
                target_shape = (batch_size, 512, 8, 8)

            x = layers.reshape(x, target_shape)
            x = tf.contrib.layers.batch_norm(x, fused=True, data_format=FLAGS.data_format)
            x = layers.lrelu(x)

            x = layers.G_conv2d_block(x, "G_conv2D1", 256, 3, data_format=FLAGS.data_format, bn=True)
            x = layers.upsampleNN(x, "G_up1", 2, data_format=FLAGS.data_format)

            x = layers.G_conv2d_block(x, "G_conv2D2", 128, 3, data_format=FLAGS.data_format, bn=True)
            x = layers.upsampleNN(x, "G_up2", 2, data_format=FLAGS.data_format)

            x = layers.G_conv2d_block(x, "G_conv2D3", 64, 3, data_format=FLAGS.data_format, bn=True)
            x = layers.upsampleNN(x, "G_up3", 2, data_format=FLAGS.data_format)

            # Last conv
            x = layers.conv2d(x, "G_conv2D4", 64, FLAGS.channels, 3, 1, "SAME", data_format=FLAGS.data_format)

            x = tf.nn.tanh(x, name=output_name)

            return x
Example #5
    def __split_heads(x, num_heads):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions.

        Args:
            x(Tensor): a 3-D input Tensor.
            num_heads(int): The number of heads.

        Returns:
            Tensor: a Tensor with shape [..., n, m/num_heads], where m is the size
                    of the last dimension of x.
        """
        if num_heads == 1:
            return x

        hidden_size = x.shape[-1]
        # reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
        # into a 4-D output:
        # [batch_size, max_sequence_length, num_heads, hidden_size_per_head].
        reshaped = layers.reshape(x=x,
                                  shape=list(x.shape[:-1]) +
                                  [num_heads, hidden_size // num_heads])

        # permute the dimensions into:
        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
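
For intuition, a NumPy-only sketch of the same reshape/transpose (illustrative only, not part of the Fluid code): a [2, 5, 8] activation split into 4 heads becomes [2, 4, 5, 2].

import numpy as np

x = np.zeros((2, 5, 8))                                # [batch, max_sequence_length, hidden]
num_heads = 4
reshaped = x.reshape(2, 5, num_heads, 8 // num_heads)  # [..., num_heads, hidden_per_head]
split = reshaped.transpose(0, 2, 1, 3)                 # [batch, num_heads, seq_len, hidden_per_head]
print(split.shape)                                     # (2, 4, 5, 2)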
Example #6
    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()

            with tf.variable_scope('Encoder'):
                _x = conv_block(x,
                                filters=16,
                                sampling='same',
                                **self.conv_block_params)
                _x = conv_block(_x,
                                filters=16,
                                sampling='down',
                                **self.conv_block_params)

                _x = conv_block(_x,
                                filters=32,
                                sampling='same',
                                **self.conv_block_params)
                _x = conv_block(_x,
                                filters=32,
                                sampling='down',
                                **self.conv_block_params)

                current_shape = _x.get_shape().as_list()[1:]
                _x = flatten(_x)
                _x = dense(_x, 512, activation_='lrelu')
                encoded = dense(_x, self.latent_dim)

            with tf.variable_scope('Decoder'):
                _x = dense(encoded, 512, activation_='lrelu')
                _x = dense(_x,
                           current_shape[0] * current_shape[1] *
                           current_shape[2],
                           activation_='lrelu')
                _x = reshape(_x, current_shape)

                _x = conv_block(_x,
                                filters=32,
                                sampling=self.upsampling,
                                **self.conv_block_params)
                _x = conv_block(_x,
                                filters=16,
                                sampling='same',
                                **self.conv_block_params)

                _x = conv_block(_x,
                                filters=16,
                                sampling=self.upsampling,
                                **self.conv_block_params)
                _x = conv_block(_x,
                                filters=self.channel,
                                sampling='same',
                                **self.last_conv_block_params)

            return encoded, _x
Example #7
    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            #################
            # Generator
            #################

            # Initial dense multiplication
            x = layers.linear(x, 512 * 8 * 8)

            # Reshape to image format
            if FLAGS.data_format == "NCHW":
                target_shape = (-1, 512, 8, 8)
            else:
                target_shape = (-1, 8, 8, 512)

            x = layers.reshape(x, target_shape)
            x = tf.contrib.layers.batch_norm(x, fused=True)
            x = tf.nn.elu(x)

            # Conv2D + Phase shift blocks
            x = layers.conv2d_block(x,
                                    "G16_conv2D_1",
                                    256,
                                    3,
                                    1,
                                    data_format=FLAGS.data_format)
            x = layers.conv2d_block(x,
                                    "G16_conv2D_2",
                                    256,
                                    3,
                                    1,
                                    data_format=FLAGS.data_format)
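            # Phase shift (sub-pixel convolution) upsampling: rearranges channel groups into a 2x larger spatial grid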
            x = layers.phase_shift(x,
                                   upsampling_factor=2,
                                   name="PS_G16",
                                   data_format=FLAGS.data_format)
            x = layers.conv2d_block(x,
                                    "G16_conv2D_3",
                                    FLAGS.channels,
                                    3,
                                    1,
                                    bn=False,
                                    activation_fn=None,
                                    data_format=FLAGS.data_format)
            x = tf.nn.tanh(x, name="x_G16")

            return x
Example #8
    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            # Initial dense multiplication
            x = layers.linear(x,
                              self.filters * self.start_dim * self.start_dim,
                              bias=True)

            # Reshape to image format
            if self.data_format == "NCHW":
                target_shape = (self.batch_size, self.filters, self.start_dim,
                                self.start_dim)
            else:
                target_shape = (self.batch_size, self.start_dim,
                                self.start_dim, self.filters)

            x = layers.reshape(x, target_shape)
            x = tf.contrib.layers.batch_norm(x, fused=True)
            x = layers.lrelu(x)

            # Upsampling2D + conv blocks
            for idx, (f, k, s, p) in enumerate(
                    zip(self.list_filters, self.list_kernel_size,
                        self.list_strides, self.list_padding)):
                name = "upsample2D_%s" % idx
                if idx == len(self.list_filters) - 1:
                    bn = False
                    bias = False
                    activation_fn = None
                else:
                    bias = True
                    bn = True
                    activation_fn = layers.lrelu
                x = layers.upsample2d_block(name,
                                            x,
                                            f,
                                            k,
                                            s,
                                            p,
                                            data_format=self.data_format,
                                            bias=bias,
                                            bn=bn,
                                            activation_fn=activation_fn)

            x = tf.nn.tanh(x, name="X_G")

            return x
Example #9
    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            x = layers.conv2d_block(x,
                                    "D64_conv2D_1",
                                    32,
                                    3,
                                    2,
                                    data_format=FLAGS.data_format,
                                    bn=False)
            x = layers.conv2d_block(x,
                                    "D64_conv2D_2",
                                    64,
                                    3,
                                    2,
                                    data_format=FLAGS.data_format)
            x = layers.conv2d_block(x,
                                    "D64_conv2D_3",
                                    128,
                                    3,
                                    2,
                                    data_format=FLAGS.data_format)
            x = layers.conv2d_block(x,
                                    "D64_conv2D_4",
                                    256,
                                    3,
                                    2,
                                    data_format=FLAGS.data_format)

            x_shape = x.get_shape().as_list()
            flat_dim = 1
            for d in x_shape[1:]:
                flat_dim *= d

            target_shape = (-1, flat_dim)
            x = layers.reshape(x, target_shape)

            x_mbd = layers.mini_batch_disc(x,
                                           num_kernels=100,
                                           dim_per_kernel=5,
                                           name="mbd64")
            x = tf.concat([x, x_mbd], axis=1)

            x = layers.linear(x, 1)

            return x
Example #10
        def __call__(self, x, reuse=False):
            with tf.variable_scope(self.name) as scope:

                if reuse:
                    scope.reuse_variables()

                M, N = x.get_shape().as_list()[-2:]
                x = scattering.Scattering(M=M, N=N, J=2)(x)
                x = tf.contrib.layers.batch_norm(x, data_format=FLAGS.data_format, fused=True, scope="scat_bn")
                x = layers.conv2d_block("CONV2D", x, 64, 1, 1, p="SAME", data_format=FLAGS.data_format, bias=True, bn=False, activation_fn=tf.nn.relu)

                target_shape = (-1, 64 * 7 * 7)
                x = layers.reshape(x, target_shape)
                x = layers.linear(x, 512, name="dense1")
                x = tf.nn.relu(x)
                x = layers.linear(x, 10, name="dense2")

                return x
Example #11
    def __call__(self, x, reuse=False, mode="D"):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            x = layers.conv2d_block(x,
                                    "D16_conv2D_1",
                                    32,
                                    3,
                                    2,
                                    data_format=FLAGS.data_format,
                                    bn=False)
            x = layers.conv2d_block(x,
                                    "D16_conv2D_2",
                                    16,
                                    3,
                                    2,
                                    data_format=FLAGS.data_format)

            x_feat = tf.identity(x, "x_feat16")

            x_shape = x.get_shape().as_list()
            flat_dim = 1
            for d in x_shape[1:]:
                flat_dim *= d

            target_shape = (-1, flat_dim)
            x = layers.reshape(x, target_shape)

            x = layers.linear(x, 1)

            x_mbd = layers.mini_batch_disc(x,
                                           num_kernels=100,
                                           dim_per_kernel=5,
                                           name="mbd16")
            x = tf.concat([x, x_mbd], axis=1)

            if mode == "D":
                return x

            else:
                return x_feat, x
Example #12
    def __call__(self, x, reuse=False, output_name=None):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            ##################
            # Encoding part
            ##################

            # First conv
            x = layers.conv2d(x, "D_conv2D1", FLAGS.channels, 32, 3, 1, "SAME", data_format=FLAGS.data_format)
            x = tf.nn.elu(x)

            # Conv blocks
            x = layers.D_conv2d_block(x, "D_enc_conv2D2", 64, 3, data_format=FLAGS.data_format)
            x = layers.D_conv2d_block(x, "D_enc_conv2D3", 128, 3, data_format=FLAGS.data_format)
            x = layers.D_conv2d_block(x, "D_enc_conv2D4", 256, 3, data_format=FLAGS.data_format)
            x = layers.D_conv2d_block(x, "D_enc_conv2D5", 256, 3, data_format=FLAGS.data_format)

            # strides = [1,1,1,1]
            # if FLAGS.data_format == "NCHW":
            #     ksize = [1,1,4,4]
            # else:
            #     ksize = [1,4,4,1]

            # x = tf.nn.avg_pool(x, ksize, strides, "VALID", data_format=FLAGS.data_format)
            # x = tf.squeeze(x, name="D_output")

            # ALTERNATIVELY:
            # Flatten
            batch_size = tf.shape(x)[0]
            other_dims = x.get_shape().as_list()[1:]
            prod_dim = 1
            for d in other_dims:
                prod_dim *= d
            x = layers.reshape(x, (batch_size, prod_dim))


            # Linear
            x = layers.linear(x, "D_FC1", 1, activation_fn=None)

            return x
Example #13
    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            x = tf.contrib.layers.batch_norm(x,
                                             data_format=self.data_format,
                                             fused=True,
                                             scope="scat_bn")
            x = layers.conv2d_block("D_conv2D1",
                                    x,
                                    32,
                                    3,
                                    2,
                                    p="SAME",
                                    data_format=self.data_format,
                                    bias=False,
                                    bn=False,
                                    activation_fn=layers.lrelu)
            x = layers.conv2d_block("D_conv2D2",
                                    x,
                                    64,
                                    3,
                                    2,
                                    p="SAME",
                                    data_format=self.data_format,
                                    bias=False,
                                    bn=True,
                                    activation_fn=layers.lrelu)
            # x = layers.conv2d_block("D_conv2D3", x, 128, 3, 2, p="SAME", data_format=self.data_format, bias=True, bn=True, activation_fn=layers.lrelu)

            x_shape = x.get_shape().as_list()
            target_shape = (-1, x_shape[-1] * x_shape[-2] * x_shape[-3])
            x = layers.reshape(x, target_shape)

            x = layers.linear(x, 1, name='dense2')

            return x
Example #14
    def __combine_heads(x):
        """
        Reshape the last two dimensions of input tensor x so that it becomes
        one dimension.

        Args:
            x(Tensor): a 4-D input Tensor with shape
                       [bs, num_heads, max_sequence_length, hidden_dim].

        Returns:
            Tensor: a Tensor with shape
                    [bs, max_sequence_length, num_heads * hidden_dim].
        """

        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        return layers.reshape(x=trans_x,
                              shape=list(map(int, [
                                  trans_x.shape[0], trans_x.shape[1],
                                  trans_x.shape[2] * trans_x.shape[3]
                              ])))
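
A NumPy-only illustration of the transpose/reshape above (not part of the Fluid code): merging 4 heads of size 2 back into a hidden size of 8.

import numpy as np

x = np.zeros((2, 4, 5, 2))           # [batch, num_heads, seq_len, hidden_per_head]
trans = x.transpose(0, 2, 1, 3)      # [batch, seq_len, num_heads, hidden_per_head]
merged = trans.reshape(2, 5, 4 * 2)  # [batch, seq_len, num_heads * hidden_per_head]
print(merged.shape)                  # (2, 5, 8)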
Example #15
    def __call__(self, x, reuse=False, output_name=None):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            # Initial dense multiplication
            x = layers.linear(x, "G_FC1", self.nb_filters * 8 * 8)

            batch_size = tf.shape(x)[0]
            if FLAGS.data_format == "NHWC":
                target_shape = (batch_size, 8, 8, self.nb_filters)
            elif FLAGS.data_format == "NCHW":
                target_shape = (batch_size, self.nb_filters, 8, 8)

            x = layers.reshape(x, target_shape)
            # x = tf.contrib.layers.batch_norm(x, fused=True, data_format=FLAGS.data_format)
            x = tf.nn.elu(x)

            x = layers.dec_conv2d_block(x, "G_conv2D1", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "G_up1", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "G_conv2D2", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "G_up2", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "G_conv2D3", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "G_up3", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "G_conv2D4", self.nb_filters, 3, data_format=FLAGS.data_format)

            # Last conv
            x = layers.conv2d(x, "G_conv2D5", self.nb_filters, FLAGS.channels, 3, 1, "SAME", data_format=FLAGS.data_format)

            x = tf.nn.tanh(x, name=output_name)

            return x
Example #16
    def __call__(self, x, reuse=False, output_name=None):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            ##################
            # Encoding part
            ##################

            # First conv
            x = layers.conv2d(x, "D_conv2D1", FLAGS.channels, self.nb_filters, 3, 1, "SAME", data_format=FLAGS.data_format)
            x = tf.nn.elu(x)

            # Conv blocks
            x = layers.enc_conv2d_block(x, "D_enc_conv2D2", self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format)
            x = layers.enc_conv2d_block(x, "D_enc_conv2D3", 2 * self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format)
            x = layers.enc_conv2d_block(x, "D_enc_conv2D4", 3 * self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format)
            x = layers.enc_conv2d_block(x, "D_enc_conv2D5", 4 * self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format, downsampling=False)

            # Flatten
            batch_size = tf.shape(x)[0]
            other_dims = x.get_shape().as_list()[1:]
            prod_dim = 1
            for d in other_dims:
                prod_dim *= d
            x = layers.reshape(x, (batch_size, prod_dim))

            # Linear
            x = layers.linear(x, "D_FC1", self.h_dim, activation_fn=None)

            ##################
            # Decoding part
            ##################

            x = layers.linear(x, "D_FC2", self.nb_filters * 8 * 8)

            batch_size = tf.shape(x)[0]
            if FLAGS.data_format == "NHWC":
                target_shape = (batch_size, 8, 8, self.nb_filters)
            elif FLAGS.data_format == "NCHW":
                target_shape = (batch_size, self.nb_filters, 8, 8)

            x = layers.reshape(x, target_shape)
            x = tf.contrib.layers.batch_norm(x, fused=True, data_format=FLAGS.data_format)
            x = tf.nn.elu(x)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D1", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "D_up1", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D2", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "D_up2", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D3", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "D_up3", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D4", self.nb_filters, 3, data_format=FLAGS.data_format)

            # Last conv
            x = layers.conv2d(x, "D_dec_conv2D5", self.nb_filters, FLAGS.channels, 3, 1, "SAME", data_format=FLAGS.data_format)
            x = tf.nn.tanh(x, name=output_name)

            return x
Example #17
    def __call__(self, x, reuse=False, output_name=None):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            ##################
            # Encoding part
            ##################

            # First conv
            x = layers.conv2d(x, "D_conv2D1", FLAGS.channels, self.nb_filters, 3, 1, "SAME", data_format=FLAGS.data_format)
            x = tf.nn.elu(x)

            # Conv blocks
            x = layers.enc_conv2d_block(x, "D_enc_conv2D2", self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format)
            x = layers.enc_conv2d_block(x, "D_enc_conv2D3", 2 * self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format)
            x = layers.enc_conv2d_block(x, "D_enc_conv2D4", 3 * self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format)
            x = layers.enc_conv2d_block(x, "D_enc_conv2D5", 4 * self.nb_filters, 3, activation_fn=tf.nn.elu, data_format=FLAGS.data_format, downsampling=False)

            # Flatten
            batch_size = tf.shape(x)[0]
            other_dims = x.get_shape().as_list()[1:]
            prod_dim = 1
            for d in other_dims:
                prod_dim *= d
            x = layers.reshape(x, (batch_size, prod_dim))

            # Linear
            x = layers.linear(x, "D_FC1", self.h_dim, activation_fn=None)

            ##################
            # Decoding part
            ##################

            x = layers.linear(x, "D_FC2", self.nb_filters * 8 * 8)

            batch_size = tf.shape(x)[0]
            if FLAGS.data_format == "NHWC":
                target_shape = (batch_size, 8, 8, self.nb_filters)
            elif FLAGS.data_format == "NCHW":
                target_shape = (batch_size, self.nb_filters, 8, 8)

            x = layers.reshape(x, target_shape)
            # x = tf.contrib.layers.batch_norm(x, fused=True, data_format=FLAGS.data_format)
            x = tf.nn.elu(x)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D1", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "D_up1", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D2", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "D_up2", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D3", self.nb_filters, 3, data_format=FLAGS.data_format)
            x = layers.upsampleNN(x, "D_up3", 2, data_format=FLAGS.data_format)

            x = layers.dec_conv2d_block(x, "D_dec_conv2D4", self.nb_filters, 3, data_format=FLAGS.data_format)

            # Last conv
            x = layers.conv2d(x, "D_dec_conv2D5", self.nb_filters, FLAGS.channels, 3, 1, "SAME", data_format=FLAGS.data_format)
            x = tf.nn.tanh(x, name=output_name)

            return x
Example #18
    def __call__(self, x, reuse=False):
        with tf.variable_scope(self.name) as scope:

            if reuse:
                # list_v = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
                # for v in list_v:
                #     print v
                # print
                # print
                # for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS):
                #     print v
                # import ipdb; ipdb.set_trace()
                scope.reuse_variables()

            # Store all layers in a dict
            d = collections.OrderedDict()

            # Initial dense multiplication
            x = layers.linear(x, self.filters * self.start_dim * self.start_dim)

            # Reshape to image format
            if self.data_format == "NCHW":
                target_shape = (self.batch_size, self.filters, self.start_dim, self.start_dim)
            else:
                target_shape = (self.batch_size, self.start_dim, self.start_dim, self.filters)

            x = layers.reshape(x, target_shape)
            x = tf.contrib.layers.batch_norm(x, fused=True)
            x = tf.nn.relu(x)

            # # Conv2D + Phase shift blocks
            # x = layers.conv2d_block("conv2D_1_1", x, 512, 3, 1, p="SAME", stddev=0.02,
            #                         data_format=self.data_format, bias=False, bn=True, activation_fn=layers.lrelu)
            # x = layers.conv2d_block("conv2D_1_2", x, 512, 3, 1, p="SAME", stddev=0.02,
            #                         data_format=self.data_format, bias=False, bn=False, activation_fn=layers.lrelu)
            # x = layers.phase_shift(x, upsampling_factor=2, name="PS1")

            # x = layers.conv2d_block("conv2D_2_1", x, 256, 3, 1, p="SAME", stddev=0.02,
            #                         data_format=self.data_format, bias=False, bn=False, activation_fn=layers.lrelu)
            # x = layers.conv2d_block("conv2D_2_2", x, 256, 3, 1, p="SAME", stddev=0.02,
            #                         data_format=self.data_format, bias=False, bn=False, activation_fn=layers.lrelu)
            # x = layers.phase_shift(x, upsampling_factor=2, name="PS2")

            # x = layers.conv2d_block("conv2D_3", x, 1, 1, 1, p="SAME", stddev=0.02,
            #                         data_format=self.data_format, bn=False)

            # # Upsampling2D + conv blocks
            # for idx, (f, k, s, p) in enumerate(zip(self.list_filters, self.list_kernel_size, self.list_strides, self.list_padding)):
            #     name = "upsample2D_%s" % idx
            #     if idx == len(self.list_filters) - 1:
            #         bn = False
            #     else:
            #         bn = True
            #     x = layers.upsample2d_block(name, x, f, k, s, p, data_format=self.data_format, bn=bn, activation_fn=layers.lrelu)

            # Transposed conv blocks
            for idx, (f, k, s, p) in enumerate(zip(self.list_filters, self.list_kernel_size, self.list_strides, self.list_padding)):
                img_size = self.start_dim * (2 ** (idx + 1))
                if self.data_format == "NCHW":
                    output_shape = (self.batch_size, f, img_size, img_size)
                else:
                    output_shape = (self.batch_size, img_size, img_size, f)
                name = "deconv2D_%s" % idx
                if idx == len(self.list_filters) - 1:
                    bn = False
                else:
                    bn = True
                x = layers.deconv2d_block(name, x, output_shape, k, s, p, data_format=self.data_format, bn=bn)

            x = tf.nn.tanh(x, name="X_G")

            return x
Example #19
def scaled_dot_product_attention(queries,
                                 keys,
                                 values,
                                 num_heads=1,
                                 dropout_rate=0.):
    """
    The dot-product attention.

    Attention mechanism can be seen as mapping a query and a set of key-value
    pairs to an output. The output is computed as a weighted sum of the values,
    where the weight assigned to each value is computed by a compatibility
    function (dot-product here) of the query with the corresponding key.

    The dot-product attention can be implemented through (batch) matrix
    multiplication as follows:

        .. math::

            Attention(Q, K, V) = softmax(\frac{QK^\mathrm{T}}{\sqrt{d_k}})V

    Refer to `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        queries (Variable): The input variable which should be a 3-D Tensor.
        keys (Variable): The input variable which should be a 3-D Tensor.
        values (Variable): The input variable which should be a 3-D Tensor.
        num_heads (int): Head number to compute the scaled dot product
            attention. Default: 1.
        dropout_rate (float): The dropout rate to drop the attention weight.
            Default: 0.0.

    Returns:
        Variable: A 3-D Tensor computed by multi-head scaled dot product\
            attention.

    Raises:
        ValueError: If input queries, keys, values are not 3-D Tensors.

    NOTES:
        1. When num_heads > 1, three linear projections are learned respectively
           to map input queries, keys and values into queries', keys' and values'.
           queries', keys' and values' have the same shapes with queries, keys
           and values.
        2. When num_heads == 1, scaled_dot_product_attention has no learnable
           parameters.

    Examples:
        .. code-block:: python

            queries = fluid.layers.data(name="queries",
                                        shape=[3, 5, 9],
                                        dtype="float32",
                                        append_batch_size=False)
            queries.stop_gradient = False
            keys = fluid.layers.data(name="keys",
                                     shape=[3, 6, 9],
                                     dtype="float32",
                                     append_batch_size=False)
            keys.stop_gradient = False
            values = fluid.layers.data(name="values",
                                       shape=[3, 6, 10],
                                       dtype="float32",
                                       append_batch_size=False)
            values.stop_gradient = False
            contexts = fluid.nets.scaled_dot_product_attention(queries, keys, values)
            contexts.shape  # [3, 5, 10]
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs queries, keys and values should all be 3-D tensors.")

    if queries.shape[-1] != keys.shape[-1]:
        raise ValueError(
            "The hidden size of queries and keys should be the same.")
    if keys.shape[-2] != values.shape[-2]:
        raise ValueError(
            "The max sequence length in query batch and in key batch "
            "should be the same.")
    if keys.shape[-1] % num_heads != 0:
        raise ValueError("The hidden size of keys (%d) must be divisible "
                         "by the number of attention heads (%d)." %
                         (keys.shape[-1], num_heads))
    if values.shape[-1] % num_heads != 0:
        raise ValueError("The hidden size of values (%d) must be divisible "
                         "by the number of attention heads (%d)." %
                         (values.shape[-1], num_heads))

    def __compute_qkv(queries, keys, values, num_heads):
        """
        Add linear projection to queries, keys, and values.

        Args:
            queries(Tensor): a 3-D input Tensor.
            keys(Tensor): a 3-D input Tensor.
            values(Tensor): a 3-D input Tensor.
            num_heads(int): The number of heads. Linearly project the inputs
                            ONLY when num_heads > 1.

        Returns:
            Tensor: linearly projected output Tensors: queries', keys' and
                    values'. They have the same shapes with queries, keys and
                    values.
        """

        if num_heads == 1:
            return queries, keys, values

        q = layers.fc(input=queries,
                      size=queries.shape[-1],
                      num_flatten_dims=2)
        k = layers.fc(input=keys, size=keys.shape[-1], num_flatten_dims=2)
        v = layers.fc(input=values, size=values.shape[-1], num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, num_heads):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions.

        Args:
            x(Tensor): a 3-D input Tensor.
            num_heads(int): The number of heads.

        Returns:
            Tensor: a Tensor with shape [..., n, m/num_heads], where m is the size
                    of the last dimension of x.
        """
        if num_heads == 1:
            return x

        hidden_size = x.shape[-1]
        # reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
        # into a 4-D output:
        # [batch_size, max_sequence_length, num_heads, hidden_size_per_head].
        reshaped = layers.reshape(x=x,
                                  shape=list(x.shape[:-1]) +
                                  [num_heads, hidden_size // num_heads])

        # permute the dimensions into:
        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Reshape the last two dimensions of input tensor x so that it becomes
        one dimension.

        Args:
            x(Tensor): a 4-D input Tensor with shape
                       [bs, num_heads, max_sequence_length, hidden_dim].

        Returns:
            Tensor: a Tensor with shape
                    [bs, max_sequence_length, num_heads * hidden_dim].
        """

        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        return layers.reshape(x=trans_x,
                              shape=list(map(int, [
                                  trans_x.shape[0], trans_x.shape[1],
                                  trans_x.shape[2] * trans_x.shape[3]
                              ])))

    q, k, v = __compute_qkv(queries, keys, values, num_heads)

    q = __split_heads(q, num_heads)
    k = __split_heads(k, num_heads)
    v = __split_heads(v, num_heads)

    key_dim_per_head = keys.shape[-1] // num_heads
    scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
    product = layers.matmul(x=scaled_q, y=k, transpose_y=True)

    weights = layers.reshape(x=layers.reshape(x=product,
                                              shape=[-1, product.shape[-1]],
                                              act="softmax"),
                             shape=product.shape)
    if dropout_rate:
        weights = layers.dropout(weights,
                                 dropout_prob=dropout_rate,
                                 is_test=False)
    ctx_multiheads = layers.matmul(weights, v)
    return __combine_heads(ctx_multiheads)
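
A short usage sketch with more than one head, mirroring the docstring example above (assuming paddle.fluid is available; layer names and shapes are illustrative):

import paddle.fluid as fluid

queries = fluid.layers.data(name="q", shape=[3, 5, 8], dtype="float32", append_batch_size=False)
keys = fluid.layers.data(name="k", shape=[3, 6, 8], dtype="float32", append_batch_size=False)
values = fluid.layers.data(name="v", shape=[3, 6, 10], dtype="float32", append_batch_size=False)
# Hidden sizes (8 and 10) are divisible by num_heads=2, as the checks above require.
contexts = scaled_dot_product_attention(queries, keys, values, num_heads=2, dropout_rate=0.1)
# contexts.shape == [3, 5, 10]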