Example 1
import tensorflow as tf  # the TensorFlow 1.x (graph-mode) API is assumed throughout these examples


def _linear_4d(inputs, output_size, bias, concat=True, data_format="NHWC"):
    """Linear projection of a list of 4-D tensors, implemented as a 1x1 convolution.
    `bias` is a boolean; `check_data_format` is assumed to be a module-level helper."""
    data_format = check_data_format(data_format)
    channel_axis = 1 if data_format == "NCHW" else -1

    input_size = [item.get_shape()[channel_axis].value for item in inputs]

    outputs = []

    if concat:
        input_size = sum(input_size)
        inputs = tf.concat(inputs, channel_axis)

        shape = [input_size, output_size]
        matrix = tf.get_variable("matrix", shape)
        matrix = tf.expand_dims(tf.expand_dims(matrix, 0), 1)
        output = tf.nn.convolution(inputs, matrix, "VALID",
                                   data_format=data_format)
        outputs.append(output)
    else:
        for i in range(len(input_size)):
            shape = [input_size[i], output_size]
            name = "matrix_%d" % i
            matrix = tf.get_variable(name, shape)
            matrix = tf.expand_dims(tf.expand_dims(matrix, 0), 1)
            # each input in the list is projected separately (no concatenation)
            output = tf.nn.convolution(inputs[i], matrix, "VALID",
                                       data_format=data_format)
            outputs.append(output)

    output = tf.add_n(outputs)

    if bias:
        bias = tf.get_variable("bias", [output_size])
        output = tf.nn.bias_add(output, bias, data_format=data_format)

    return output
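
A minimal usage sketch for the helper above (hypothetical shapes and variable names; assumes TensorFlow 1.x graph mode and that check_data_format simply validates and returns the format string):

# Hypothetical usage: project two NHWC feature maps to a common channel size.
a = tf.placeholder(tf.float32, [None, 16, 16, 128])
b = tf.placeholder(tf.float32, [None, 16, 16, 256])

with tf.variable_scope("projection"):
    # concat=True: the 128 + 256 channels are concatenated first, so a single
    # [1, 1, 384, 512] filter named "projection/matrix" is created.
    out = _linear_4d([a, b], 512, True, concat=True)
# out has shape [batch, 16, 16, 512]
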
Example 2
def layer_norm(inputs,
               epsilon=1e-6,
               data_format="NHWC",
               dtype=None,
               scope=None):
    """Layer normalization: normalize over the channel axis, then apply a
    learned per-channel scale and offset."""
    with tf.variable_scope(scope,
                           default_name="layer_norm",
                           values=[inputs],
                           dtype=dtype):
        data_format = check_data_format(data_format)
        axis = 1 if data_format == "NCHW" else -1
        channel_size = inputs.get_shape().as_list()[axis]

        scale = tf.get_variable("scale",
                                shape=[channel_size],
                                initializer=tf.ones_initializer())

        offset = tf.get_variable("offset",
                                 shape=[channel_size],
                                 initializer=tf.zeros_initializer())

        # per-position mean and variance over the channel axis
        mean = tf.reduce_mean(inputs, axis=axis, keep_dims=True)
        variance = tf.reduce_mean(tf.square(inputs - mean),
                                  axis=axis,
                                  keep_dims=True)

        norm_inputs = (inputs - mean) * tf.rsqrt(variance + epsilon)

        return norm_inputs * scale + offset
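
A short usage sketch for layer_norm (hypothetical tensor shapes; TensorFlow 1.x graph mode assumed):

# Hypothetical usage: normalize encoder states of shape [batch, length, channels].
x = tf.placeholder(tf.float32, [None, 50, 512])
y = layer_norm(x, data_format="NHWC", scope="encoder_norm")
# y has the same shape as x; the trainable variables are created as
# "encoder_norm/scale" and "encoder_norm/offset", both of shape [512].
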
Example 3
def multihead_attention(query,
                        memory,
                        bias,
                        key_size,
                        value_size,
                        output_size,
                        num_heads,
                        keep_prob=None,
                        data_format="NHWC",
                        attention_function="dot_product",
                        summaries=False,
                        image_shapes=None,
                        dtype=None,
                        scope=None):
    """ Multihead scaled-dot-product attention with input/output
        transformations.

    Args:
        query: a Tensor with shape [batch, length_q, channels] if
            data_format is `NHWC`, [batch, channels, length_q] if
            data_format is `NCHW`
        memory: a Tensor with shape [batch, length_m, channels] if
            data_format is `NHWC`, [batch, channels, length_m] if
            data_format is `NCHW`
        bias: bias Tensor (see attention_bias())
        key_size: an integer
        value_size: an integer
        output_size: an integer
        num_heads: an integer dividing key_size and value_size
        keep_prob: a floating point number
        summaries: a boolean
        image_shapes: optional tuple of integer scalars.
            see comments for attention_image_summary()
        data_format: "NHWC" or "NCHW"
        attention_function: "dot_product" or "additive"
        dtype: an optional instance of tf.DType
        scope: an optional string

    Returns:
        A Tensor.
    """
    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope,
                           default_name="multihead_attention",
                           values=[query, memory],
                           dtype=dtype):
        data_format = check_data_format(data_format)
        axis = 1 if data_format == "NCHW" else 2

        if memory is None:
            # self attention
            size = key_size * 2 + value_size
            combined = linear(query,
                              size,
                              True,
                              True,
                              data_format=data_format,
                              scope="qkv_transform")
            q, k, v = tf.split(combined, [key_size, key_size, value_size],
                               axis=axis)
        else:
            q = linear(query,
                       key_size,
                       True,
                       data_format=data_format,
                       scope="q_transform")
            combined = linear(memory,
                              key_size + value_size,
                              True,
                              data_format=data_format,
                              scope="kv_transform")
            k, v = tf.split(combined, [key_size, value_size], axis=axis)

        # split heads
        q = _split_heads(q, num_heads, data_format=data_format)
        k = _split_heads(k, num_heads, data_format=data_format)
        v = _split_heads(v, num_heads, data_format=data_format)

        # scale query
        if attention_function == "dot_product":
            key_depth_per_head = key_size // num_heads
            q *= key_depth_per_head**-0.5

            # attention
            x = dot_product_attention(q, k, v, bias, keep_prob, summaries,
                                      image_shapes)
        elif attention_function == "additive":
            x = additive_attention(q, k, v, bias, keep_prob, summaries,
                                   image_shapes)
        else:
            raise ValueError("Unknown attention function")

        # combine heads
        x = _combine_heads(x, data_format=data_format)

        x = linear(x,
                   output_size,
                   True,
                   data_format=data_format,
                   scope="output_transform")
        return x
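
A hedged usage sketch of the self-attention path (memory=None). The sizes are hypothetical, and the zero bias tensor merely stands in for whatever attention_bias() from the same module would produce:

# Hypothetical self-attention call in NHWC layout ([batch, length_q, channels]).
x = tf.placeholder(tf.float32, [None, 20, 512])
bias = tf.zeros([1, 1, 1, 1])  # placeholder for attention_bias(); applies no masking
y = multihead_attention(query=x,
                        memory=None,       # self-attention: q, k, v all derive from x
                        bias=bias,
                        key_size=512,
                        value_size=512,
                        output_size=512,
                        num_heads=8,       # key_size and value_size must be divisible by this
                        keep_prob=0.9)
# y has shape [batch, 20, 512] (output_size channels).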