def __init__(self,
             num_heads,
             num_units,
             attention_key_depth=None,
             attention_value_depth=None,
             output_depth=None,
             attention_dropout_rate=0.1,
             attention_type="dot_product",
             name=None):
    """ Initializes the multi head attention layer.

    Args:
        num_heads: An int scalar, the number of heads.
        num_units: An int scalar, the default units if the other `depth` arguments are not provided.
        attention_key_depth: An int scalar, the dimension for projected attention keys.
            If not provided, `num_units` is used as the default.
        attention_value_depth: An int scalar, the dimension for projected attention values.
            If not provided, `num_units` is used as the default.
        output_depth: An int scalar, the dimension for projected outputs.
            If not provided, `num_units` is used as the default.
        attention_dropout_rate: A float scalar, the dropout rate applied to the attention weights.
        attention_type: A string indicating the attention type.
        name: The name of the layer.
    """
    self._params = extract_constructor_params(locals(), verbose=False)
    super(MultiHeadAttention, self).__init__(name=name)
    self._num_heads = num_heads
    self._num_units = num_units
    self._attention_key_depth = attention_key_depth or num_units
    self._attention_value_depth = attention_value_depth or num_units
    self._output_depth = output_depth or num_units
    self._attention_dropout_rate = attention_dropout_rate
    self._attention_type = attention_type
    if self._attention_key_depth % self._num_heads != 0:
        raise ValueError(
            "query depth ({}) must be divisible by the number of "
            "attention heads ({}).".format(self._attention_key_depth, self._num_heads))
    if self._attention_value_depth % self._num_heads != 0:
        raise ValueError(
            "value depth ({}) must be divisible by the number of "
            "attention heads ({}).".format(self._attention_value_depth, self._num_heads))
    # Pre-create the output transform layer that merges the heads back
    # into a single `output_depth`-dimensional representation.
    self._output_transform_layer = MultiHeadDenseLayer(
        output_units=self._output_depth,
        num_heads=self._num_heads,
        kernel_initializer="glorot_uniform",
        is_output_transform=True,
        use_bias=True,
        name="output_transform")
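# A minimal construction sketch (hypothetical values, not taken from this
# repo's configs; assumes the MultiHeadAttention class above is importable):
# 512 units split across 8 heads gives 64 dimensions per head, so both
# divisibility checks in __init__ pass.
demo_attention = MultiHeadAttention(
    num_heads=8,
    num_units=512,
    attention_dropout_rate=0.1,
    attention_type="dot_product",
    name="demo_attention")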
def build(self, input_shape):
    """ Builds the layer.

    Creates the layers for linearly projecting the queries, keys and values.
    """
    self._q_transform_layer = MultiHeadDenseLayer(
        output_units=self._attention_key_depth,
        num_heads=self._num_heads,
        kernel_initializer="glorot_uniform",
        is_output_transform=False,
        use_bias=True,
        name="q_transform")
    self._kv_transform_layer = MultiHeadDenseLayer(
        output_units=[self._attention_key_depth,
                      self._attention_value_depth],
        num_heads=self._num_heads,
        kernel_initializer="glorot_uniform",
        is_output_transform=False,
        use_bias=True,
        name="kv_transform")
    self.add_activation_quantizer(name="output", activation="act")
    self.add_activation_quantizer(name="softmax", activation="softmax")
    self.built = True
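# A hypothetical shape sketch for the separate projections above (illustrative
# sizes; assumes this package's MultiHeadDenseLayer and TensorFlow). Splitting
# the projections this way suits cross-attention, where queries come from one
# sequence and keys/values from another ("memory") sequence.
import tensorflow as tf

q_layer = MultiHeadDenseLayer(
    output_units=64, num_heads=4, kernel_initializer="glorot_uniform",
    is_output_transform=False, use_bias=True, name="q_demo")
kv_layer = MultiHeadDenseLayer(
    output_units=[64, 64], num_heads=4, kernel_initializer="glorot_uniform",
    is_output_transform=False, use_bias=True, name="kv_demo")
q = q_layer(tf.zeros([2, 7, 64]))       # queries -> (2, 7, 4, 16)
k, v = kv_layer(tf.zeros([2, 11, 64]))  # memory  -> (2, 11, 4, 16) each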
def build(self, input_shape):
    """ Builds the layer.

    Creates a single fused layer that linearly projects the queries, keys
    and values in one pass.
    """
    self._qkv_transform_layer = MultiHeadDenseLayer(
        output_units=[self._attention_key_depth,
                      self._attention_key_depth,
                      self._attention_value_depth],
        num_heads=self._num_heads,
        kernel_initializer="glorot_uniform",
        is_output_transform=False,
        use_bias=True,
        name="qkv_transform")
    self.add_activation_quantizer(name="output", activation="act")
    self.add_activation_quantizer(name="softmax", activation="softmax")
    self.built = True
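# The fused variant above computes all three projections with a single matmul,
# which fits self-attention where queries, keys and values share one input.
# A hypothetical sketch (illustrative sizes; assumes this package's
# MultiHeadDenseLayer and `import tensorflow as tf`):
qkv_layer = MultiHeadDenseLayer(
    output_units=[64, 64, 64], num_heads=4, kernel_initializer="glorot_uniform",
    is_output_transform=False, use_bias=True, name="qkv_demo")
q, k, v = qkv_layer(tf.zeros([2, 5, 64]))  # each -> (2, 5, 4, 16)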
import numpy
import tensorflow as tf

# MultiHeadDenseLayer is assumed importable from the package defining the
# layer shown above.


def test_multihead_dense():
    # Case 1: non-output transform with a single output size; the layer should
    # match a plain dense projection reshaped into per-head slices.
    num_heads = 3
    output_size = 6
    non_out_layer = MultiHeadDenseLayer(
        output_size, num_heads,
        use_bias=True,
        is_output_transform=False,
        name="nonoutput_transform")
    inputs = tf.convert_to_tensor(numpy.random.randn(2, 3, 6), dtype=tf.float32)
    layer_out = non_out_layer(inputs)
    kernel, bias = None, None
    for w in non_out_layer.trainable_weights:
        if "kernel" in w.name:
            kernel = w
        else:
            bias = w
    manual_out = tf.einsum("abc,cd->abd", inputs, kernel) + bias
    manual_out = tf.reshape(
        manual_out,
        tf.concat([tf.shape(manual_out)[:-1],
                   [num_heads, output_size // num_heads]], axis=0))
    assert numpy.sum((manual_out.numpy() - layer_out.numpy()) ** 2) < 1e-9

    # Case 2: output transform; the per-head inputs are flattened and projected
    # back to a single output dimension.
    num_inputs_per_head = 5
    out_layer = MultiHeadDenseLayer(
        output_size, num_heads,
        use_bias=True,
        is_output_transform=True,
        name="output_transform")
    inputs = tf.convert_to_tensor(
        numpy.random.randn(1, 2, num_heads, num_inputs_per_head), dtype=tf.float32)
    layer_out = out_layer(inputs)
    kernel, bias = None, None
    for w in out_layer.trainable_weights:
        if "kernel" in w.name:
            kernel = w
        else:
            bias = w
    manual_out = tf.matmul(
        tf.reshape(inputs, tf.concat([tf.shape(inputs)[:-2], [-1]], 0)),
        kernel) + bias
    assert numpy.sum((manual_out.numpy() - layer_out.numpy()) ** 2) < 1e-9

    # Case 3: non-output transform with a list of output sizes; the fused
    # projection is split and each part reshaped into per-head slices.
    output_size1 = 9
    non_out_multi_layer = MultiHeadDenseLayer(
        [output_size, output_size1], num_heads,
        use_bias=True,
        is_output_transform=False,
        name="nonoutput_transform")
    inputs = tf.convert_to_tensor(numpy.random.randn(2, 3, 6), dtype=tf.float32)
    layer_out0, layer_out1 = non_out_multi_layer(inputs)
    kernel, bias = None, None
    for w in non_out_multi_layer.trainable_weights:
        if "kernel" in w.name:
            kernel = w
        else:
            bias = w
    manual_out = tf.einsum("abc,cd->abd", inputs, kernel) + bias
    manual_out0, manual_out1 = tf.split(
        manual_out, [output_size, output_size1], axis=-1)
    manual_out0 = tf.reshape(
        manual_out0,
        tf.concat([tf.shape(manual_out0)[:-1],
                   [num_heads, output_size // num_heads]], axis=0))
    manual_out1 = tf.reshape(
        manual_out1,
        tf.concat([tf.shape(manual_out1)[:-1],
                   [num_heads, output_size1 // num_heads]], axis=0))
    assert numpy.sum((manual_out0.numpy() - layer_out0.numpy()) ** 2) < 1e-9
    assert numpy.sum((manual_out1.numpy() - layer_out1.numpy()) ** 2) < 1e-9
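# Optional standalone entry point (an assumption, not part of the original
# test file): lets the checks above run directly with `python <this file>`
# instead of through a test runner.
if __name__ == "__main__":
    test_multihead_dense()
    print("test_multihead_dense passed")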