Example No. 1
    def attention_block(self, entities, reuse, name="attention_block"):
        """Performs non-local pairwise relational computations.

        Args:
          entities: A tensor of shape (B, K, D) where K is the number of entities.
          reuse: Whether to reuse the weights.
          name: The name of the block.

        Returns:
          Updated entity representation of shape (B, K, D).
        """
        # Estimate local dimensions to support background channel.
        k, z_dim = entities.get_shape().as_list()[1:3]

        r_entities = tf.reshape(entities, [self.batch_size * k, z_dim])

        with tf.variable_scope(name, reuse=reuse):
            queries = ops.layer_norm(
                tf.nn.relu(
                    ops.linear(r_entities, self.embedding_dim, scope="q_fc")),
                reuse, "q_ln")
            queries = tf.reshape(queries,
                                 [self.batch_size, k, self.embedding_dim])

            keys = ops.layer_norm(
                tf.nn.relu(
                    ops.linear(r_entities, self.embedding_dim, scope="k_fc")),
                reuse, "k_ln")
            keys = tf.reshape(keys, [self.batch_size, k, self.embedding_dim])

            values = ops.layer_norm(
                tf.nn.relu(
                    ops.linear(r_entities, self.embedding_dim, scope="v_fc")),
                reuse, "v_ln")
            values = tf.reshape(values,
                                [self.batch_size, k, self.embedding_dim])

            attention_weights = tf.matmul(queries,
                                          tf.transpose(keys, [0, 2, 1]))
            norm_attention_weights = tf.nn.softmax(
                attention_weights /
                tf.sqrt(tf.cast(self.embedding_dim, tf.float32)),
                axis=2)

            attention = tf.matmul(norm_attention_weights, values)
            r_attention = tf.reshape(attention,
                                     [self.batch_size * k, self.embedding_dim])

            # Project back to original space.
            u_entities = tf.nn.relu(ops.linear(r_attention, z_dim, "e_fc1"))
            u_entities = tf.nn.relu(ops.linear(u_entities, z_dim, "e_fc2"))
            u_entities = ops.layer_norm(u_entities + r_entities, reuse, "e_ln")

            return tf.reshape(u_entities, [self.batch_size, k, z_dim])
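
Stripped of the reshapes and the per-projection layer norms, the attention step above is standard scaled dot-product self-attention over the K entities. A minimal NumPy sketch of just that step (the function name and the unparameterized softmax are illustrative, not part of the original code):

import numpy as np


def scaled_dot_product_attention(queries, keys, values):
    """NumPy sketch of the attention step above; all inputs are (B, K, D)."""
    d = queries.shape[-1]
    # Pairwise scores between entities, scaled by sqrt(D): shape (B, K, K).
    scores = queries @ np.swapaxes(keys, 1, 2) / np.sqrt(d)
    # Numerically stable softmax over the last axis (the attended-to entities).
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Each entity's update is a weighted sum of all value vectors: (B, K, D).
    return weights @ values

Each row of `weights` sums to one, so every entity's update is a convex combination of the value vectors of all entities, which is what makes the computation non-local.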
Example No. 2
    def aggregate_heads(self, heads, reuse, name="aggregate_heads"):
        """Returns the aggregated heads."""
        # Estimate local dimensions to support background channel.
        k, z_dim = heads[0].get_shape().as_list()[1:3]

        with tf.variable_scope(name, reuse=reuse):
            # Concatenate the heads along the feature axis: (B, K, n_heads * D).
            heads = tf.concat(heads, axis=2)
            heads_r = tf.reshape(heads,
                                 [self.batch_size * k, self.n_heads * z_dim])
            # Project the concatenated heads back down to z_dim.
            heads_a = tf.nn.relu(ops.linear(heads_r, z_dim, "a_fc1"))
            heads_a = ops.layer_norm(heads_a, reuse, "a_ln")
            heads_a = tf.reshape(heads_a, [self.batch_size, k, z_dim])

            return heads_a
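
Without the TF1 plumbing (`ops.linear`, `ops.layer_norm`, variable scopes), the aggregation is a concatenation along the feature axis followed by a learned projection back to `z_dim`. A rough NumPy sketch, where `w` and `b` are hypothetical stand-ins for the `a_fc1` weights and the layer norm carries no learned gain or bias:

import numpy as np


def aggregate_heads_np(heads, w, b, eps=1e-6):
    """NumPy sketch: `heads` is a list of H arrays of shape (B, K, D),
    `w` has shape (H * D, D) and `b` has shape (D,)."""
    concat = np.concatenate(heads, axis=2)         # (B, K, H * D)
    projected = np.maximum(concat @ w + b, 0.0)    # linear projection + ReLU
    # Layer norm over the feature axis, without a learned gain or bias.
    mean = projected.mean(axis=-1, keepdims=True)
    std = projected.std(axis=-1, keepdims=True)
    return (projected - mean) / (std + eps)        # (B, K, D)

In the class above, `heads` would typically be the list of `n_heads` outputs of `attention_block`, each of shape (B, K, D).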
Example No. 3
def resnet_block(inputs, in_channels, out_channels, scale,
                 block_scope, is_training, reuse, discriminator_normalization,
                 is_gen_block):
  """ResNet block with a convolved shortcut and optional up/down-scaling.

  Args:
    inputs: Input tensor whose last dimension must equal `in_channels`.
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    scale: One of "up", "down" or "none".
    block_scope: Name of the variable scope for the block.
    is_training: Whether the model is in training mode (used by batch norm).
    reuse: Whether to reuse the weights.
    discriminator_normalization: Normalization scheme for the discriminator,
      e.g. consts.SPECTRAL_NORM, consts.BATCH_NORM or consts.LAYER_NORM.
    is_gen_block: Whether this is a generator block (True) or a discriminator
      block (False).

  Returns:
    Output tensor of the block (residual path plus shortcut).
  """
  assert scale in ["up", "down", "none"]
  if inputs.get_shape().as_list()[-1] != in_channels:
    raise ValueError("Unexpected number of input channels.")

  # In the SN paper, upscaling in the generator happens in the first conv;
  # in the discriminator, downsampling happens after the second conv.
  if is_gen_block:
    # Generator block
    scale1 = scale  # "up" or "none"
    scale2 = "none"
  else:
    # Discriminator block.
    scale1 = "none"
    scale2 = scale  # "down" or "none"

  print ("resnet_block, in=%d out=%d, scale=%s, scope=%s normalizer=%s" % (
      in_channels, out_channels, scale, block_scope,
      discriminator_normalization))
  print ("INPUTS: ", inputs.get_shape())
  with tf.variable_scope(block_scope, values=[inputs], reuse=reuse):
    output = inputs
    use_sn = discriminator_normalization == consts.SPECTRAL_NORM

    # Define the skip connection. Ensure 'conv' is in the suffix; otherwise it
    # will not be regularized.
    shortcut = get_conv(
        output, in_channels, out_channels, scale,
        suffix="conv_shortcut", use_sn=use_sn)
    print ("SHORTCUT: ", shortcut.get_shape())

    # Generators always use batch norm; the discriminator only if enabled.
    if is_gen_block or discriminator_normalization == consts.BATCH_NORM:
      output = batch_norm_resnet(output, is_training=is_training, scope="bn1")
    elif discriminator_normalization == consts.LAYER_NORM:
      output = ops.layer_norm(output, is_training=is_training, scope="ln1")

    output = tf.nn.relu(output)
    output = get_conv(
        output, in_channels, out_channels, scale1,
        suffix="conv1", use_sn=use_sn)
    print ("OUTPUT CONV1: ", output.get_shape())

    # Generators always use batch norm; the discriminator only if enabled.
    if is_gen_block or discriminator_normalization == consts.BATCH_NORM:
      output = batch_norm_resnet(output, is_training=is_training, scope="bn2")
    elif discriminator_normalization == consts.LAYER_NORM:
      output = ops.layer_norm(output, is_training=is_training, scope="ln2")

    output = tf.nn.relu(output)
    output = get_conv(
        output, out_channels, out_channels, scale2,
        suffix="conv2", use_sn=use_sn)
    print ("OUTPUT CONV2: ", output.get_shape())

    # Combine skip-connection with the convolved part.
    output += shortcut

    return output
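
The helpers used above (`get_conv`, `batch_norm_resnet`, `ops.layer_norm`, `consts`) live in the surrounding repository and are not shown here. As a rough, standalone illustration of the generator-variant layout only, a sketch written against `tf.keras` (an assumption; the original uses custom TF1 ops and optional spectral normalization) could look like this:

import tensorflow as tf  # assumes TF 2.x / tf.keras


def gen_resnet_block_sketch(x, out_channels, scale="none"):
  """Sketch of the generator block: norm -> ReLU -> conv, twice, plus a
  convolved shortcut; any upscaling happens around the first conv."""
  assert scale in ["up", "none"]
  shortcut = x
  if scale == "up":
    # Generator blocks upscale early; the original folds this into conv1.
    x = tf.keras.layers.UpSampling2D()(x)
    shortcut = tf.keras.layers.UpSampling2D()(shortcut)
  # 1x1 conv on the shortcut so channel counts match before the addition.
  shortcut = tf.keras.layers.Conv2D(out_channels, 1, padding="same")(shortcut)

  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.ReLU()(x)
  x = tf.keras.layers.Conv2D(out_channels, 3, padding="same")(x)

  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.ReLU()(x)
  x = tf.keras.layers.Conv2D(out_channels, 3, padding="same")(x)

  # Combine the skip connection with the convolved path.
  return x + shortcut

The discriminator variant would instead keep the first conv at the input resolution and downsample (e.g. with average pooling) after the second conv, matching the `scale1`/`scale2` split above.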