def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
  """Test with a mask tensor on inputs that have extra dimensions.

  The same query/value data is run through one model twice: once with a
  random boolean mask and once with an all-ones ("null") mask. The test
  asserts that the two outputs are not close.
  """
  layer = attention.ReuseMultiHeadAttention(
      num_heads=2, key_dim=2, attention_axes=attention_axes)
  batch_size, hidden_size = 3, 8

  # Full array shapes include the batch dimension; Keras inputs drop it.
  query_shape = [batch_size] + q_dims + [hidden_size]
  value_shape = [batch_size] + v_dims + [hidden_size]
  mask_shape = [batch_size] + mask_dims

  query_data = 10 * np.random.random_sample(query_shape)
  value_data = 10 * np.random.random_sample(value_shape)
  # Random 0/1 values cast to bool. This is expected to mask at least one
  # element, but it is not guaranteed because the values are random.
  random_mask = np.random.randint(2, size=mask_shape).astype("bool")
  # All ones: no element is masked.
  full_mask = np.ones(mask_shape)

  query_input = tf.keras.Input(query_shape[1:], name="query")
  value_input = tf.keras.Input(value_shape[1:], name="value")
  mask_input = tf.keras.Input(mask_shape[1:], name="mask")
  attended = layer(
      query=query_input, value=value_input, attention_mask=mask_input)
  model = tf.keras.Model([query_input, value_input, mask_input], attended)

  # One run is masked and the other is not, so the outputs should differ.
  masked_result = model.predict([query_data, value_data, random_mask])
  unmasked_result = model.predict([query_data, value_data, full_mask])
  self.assertNotAllClose(masked_result, unmasked_result)
def test_non_masked_self_attention_with_relative_pe(
    self, reuse_attention, pe_max_seq_length):
  """Test relative-position-encoding attention without a mask tensor.

  Covers self-attention (the query tensor is passed again as the value) and
  attention between sequences of different lengths. Every case feeds scores
  through `reuse_attention_scores` with shape
  (batch, 12, query_length, value_length), where 12 equals `num_heads`. The
  output is expected to keep the query's shape.
  """
  num_heads = 12
  hidden_size = 80
  test_layer = attention.ReuseMultiHeadAttention(
      num_heads=num_heads,
      key_dim=64,
      reuse_attention=reuse_attention,
      use_relative_pe=True,
      pe_max_seq_length=pe_max_seq_length)
  # Each case is (query_length, value_length). A value_length of None means
  # self-attention: the query tensor itself is passed as the value.
  cases = ((40, None), (30, None), (30, 20), (50, 60))
  for query_length, value_length in cases:
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(query_length, hidden_size))
    self_attention = value_length is None
    kv_length = query_length if self_attention else value_length
    if self_attention:
      value = query
    else:
      value = tf.keras.Input(shape=(kv_length, hidden_size))
    reuse_scores = tf.keras.Input(shape=(num_heads, query_length, kv_length))
    output = test_layer(query, value, reuse_attention_scores=reuse_scores)
    self.assertEqual(output.shape.as_list(),
                     [None, query_length, hidden_size])
def test_non_masked_self_attention(self):
  """Test with one input (self-attention) and no mask tensor."""
  test_layer = attention.ReuseMultiHeadAttention(num_heads=12, key_dim=64)
  # Create a 3-dimensional input (the first dimension is implicit).
  query = tf.keras.Input(shape=(40, 80))
  # Passing the same tensor as query and value makes this self-attention.
  output = test_layer(query, query)
  # The output is expected to keep the query's shape.
  self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
  """Test attention outputs together with the attention coefficients.

  With `return_attention_scores=True` the layer returns an (output, scores)
  pair. The scores hold one 40x40 map per head for this self-attention input.
  """
  num_heads = 12
  layer = attention.ReuseMultiHeadAttention(num_heads=num_heads, key_dim=64)
  # Keras input of shape (batch, 40, 80); the batch dimension is implicit.
  inputs = tf.keras.Input(shape=(40, 80))
  result = layer(inputs, inputs, return_attention_scores=True)
  output, scores = result
  self.assertEqual(output.shape.as_list(), [None, 40, 80])
  self.assertEqual(scores.shape.as_list(), [None, num_heads, 40, 40])
def test_initializer(self):
  """Test that the layer builds with an explicit kernel initializer.

  Only the output shape is checked; the initialized weight values are not
  inspected.
  """
  initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
  layer = attention.ReuseMultiHeadAttention(
      num_heads=12, key_dim=64, kernel_initializer=initializer)
  # Keras input of shape (batch, 40, 80); the batch dimension is implicit.
  inputs = tf.keras.Input(shape=(40, 80))
  result = layer(inputs, inputs)
  self.assertEqual(result.shape.as_list(), [None, 40, 80])
def test_non_masked_attention(self, value_dim, output_shape, output_dims):
  """Test that the attention layer can be created without a mask tensor.

  Args:
    value_dim: Forwarded to the layer's `value_dim` argument.
    output_shape: Forwarded to the layer's `output_shape` argument.
    output_dims: Expected output shape without the batch dimension.
  """
  layer = attention.ReuseMultiHeadAttention(
      num_heads=12,
      key_dim=64,
      value_dim=value_dim,
      output_shape=output_shape)
  # Query and value sequences of different lengths; the batch is implicit.
  query = tf.keras.Input(shape=(40, 80))
  value = tf.keras.Input(shape=(20, 80))
  result = layer(query=query, value=value)
  expected_shape = [None] + output_dims
  self.assertEqual(result.shape.as_list(), expected_shape)
def test_dropout(self):
  """Test that attention dropout is applied only in training mode."""
  test_layer = attention.ReuseMultiHeadAttention(
      num_heads=2, key_dim=2, dropout=0.5)
  # Generate data for the input (non-mask) tensors.
  from_data = tf.keras.backend.ones(shape=(32, 4, 8))
  to_data = tf.keras.backend.ones(shape=(32, 2, 8))
  # Pass `training` by keyword. The previous call relied on it being the
  # sixth positional parameter, behind placeholder Nones.
  train_out = test_layer(from_data, to_data, training=True)
  test_out = test_layer(from_data, to_data, training=False)
  test_out_again = test_layer(from_data, to_data, training=False)
  # Outside training mode dropout is disabled, so repeated calls must agree.
  self.assertAllClose(tf.keras.backend.eval(test_out),
                      tf.keras.backend.eval(test_out_again))
  # With dropout enabled in training mode the output should differ.
  self.assertNotAllClose(tf.keras.backend.eval(train_out),
                         tf.keras.backend.eval(test_out))
def test_masked_attention_with_scores(self):
  """Test with a mask tensor, with and without returning attention scores.

  Checks three things:
  - masking changes both the output and the returned scores;
  - requesting the scores leaves the output unchanged.
  """
  layer = attention.ReuseMultiHeadAttention(num_heads=2, key_dim=2)
  batch_size = 3
  # Keras inputs leave the batch dimension implicit.
  query = tf.keras.Input(shape=(4, 8))
  value = tf.keras.Input(shape=(2, 8))
  mask_tensor = tf.keras.Input(shape=(4, 2))

  output_only = layer(query=query, value=value, attention_mask=mask_tensor)
  output_model = tf.keras.Model([query, value, mask_tensor], output_only)

  from_data = 10 * np.random.random_sample((batch_size, 4, 8))
  to_data = 10 * np.random.random_sample((batch_size, 2, 8))
  # Random 0/1 mask, expected to mask at least one element.
  mask_data = np.random.randint(2, size=(batch_size, 4, 2))
  # All-ones mask: nothing is masked.
  null_mask_data = np.ones((batch_size, 4, 2))

  masked_plain = output_model.predict([from_data, to_data, mask_data])
  unmasked_plain = output_model.predict([from_data, to_data, null_mask_data])
  # One run is masked and one is not, so the outputs should differ.
  self.assertNotAllClose(masked_plain, unmasked_plain)

  # Same layer, now also returning the attention scores.
  output_with_scores, scores = layer(
      query=query,
      value=value,
      attention_mask=mask_tensor,
      return_attention_scores=True)
  scores_model = tf.keras.Model(
      [query, value, mask_tensor], [output_with_scores, scores])
  masked_out, masked_scores = scores_model.predict(
      [from_data, to_data, mask_data])
  unmasked_out, unmasked_scores = scores_model.predict(
      [from_data, to_data, null_mask_data])

  self.assertNotAllClose(masked_out, unmasked_out)
  self.assertAllClose(masked_plain, masked_out)
  self.assertAllClose(unmasked_plain, unmasked_out)
  self.assertNotAllClose(masked_scores, unmasked_scores)
def test_masked_attention(self, use_bias, reuse_attention):
  """Test with a mask tensor, using two-input (Q, V) and three-input (Q, K, V) calls.

  Precomputed scores are always supplied via `reuse_attention_scores`.
  When `reuse_attention == -1`, the masked and unmasked outputs are asserted
  to be exactly equal. Otherwise they are asserted to differ. Finally the
  number of sublayers and trainable variables is checked against `use_bias`
  and `reuse_attention`.
  """
  layer = attention.ReuseMultiHeadAttention(
      num_heads=2, key_dim=2, use_bias=use_bias,
      reuse_attention=reuse_attention)

  def check_mask_effect(masked, unmasked):
    # With reuse_attention == -1 the mask is expected to have no effect
    # (presumably every head reuses the supplied scores -- confirm against
    # the layer). Otherwise masking must change the output.
    if reuse_attention == -1:
      self.assertAllEqual(masked, unmasked)
    else:
      self.assertNotAllClose(masked, unmasked)

  batch_size = 3
  # Keras inputs leave the batch dimension implicit.
  query = tf.keras.Input(shape=(4, 8))
  value = tf.keras.Input(shape=(2, 8))
  mask_tensor = tf.keras.Input(shape=(4, 2))
  reuse_attention_scores = tf.keras.Input(shape=(2, 4, 2))
  qv_output = layer(query=query, value=value, attention_mask=mask_tensor,
                    reuse_attention_scores=reuse_attention_scores)
  qv_model = tf.keras.Model(
      [query, value, mask_tensor, reuse_attention_scores], qv_output)

  from_data = 10 * np.random.random_sample((batch_size, 4, 8))
  to_data = 10 * np.random.random_sample((batch_size, 2, 8))
  reuse_scores = np.random.random_sample((batch_size, 2, 4, 2))
  # Random 0/1 mask, expected to mask at least one element.
  mask_data = np.random.randint(2, size=(batch_size, 4, 2))
  masked_result = qv_model.predict(
      [from_data, to_data, mask_data, reuse_scores])
  # All-ones mask: nothing is masked.
  null_mask_data = np.ones((batch_size, 4, 2))
  unmasked_result = qv_model.predict(
      [from_data, to_data, null_mask_data, reuse_scores])
  check_mask_effect(masked_result, unmasked_result)

  # Same layer, now called with an explicit key: Q, K, V.
  key = tf.keras.Input(shape=(2, 8))
  qkv_output = layer(query, value=value, key=key, attention_mask=mask_tensor,
                     reuse_attention_scores=reuse_attention_scores)
  qkv_model = tf.keras.Model(
      [query, value, key, mask_tensor, reuse_attention_scores], qkv_output)
  masked_result = qkv_model.predict(
      [from_data, to_data, to_data, mask_data, reuse_scores])
  unmasked_result = qkv_model.predict(
      [from_data, to_data, to_data, null_mask_data, reuse_scores])
  check_mask_effect(masked_result, unmasked_result)

  if reuse_attention > 0:
    self.assertLen(layer._output_dense, 2)
  # Each dense has a kernel, plus a bias when use_bias is set.
  expected_num_variables = 2 if use_bias else 1
  if reuse_attention == 0:
    self.assertLen(layer._query_dense.trainable_variables,
                   expected_num_variables)
  self.assertLen(layer._output_dense[0].trainable_variables,
                 expected_num_variables)
  if len(layer._output_dense) == 2:
    self.assertLen(layer._output_dense[1].trainable_variables, 1)
def build(self, input_shape):
  """Creates the attention, feed-forward and normalization sublayers.

  Args:
    input_shape: Either a `tf.TensorShape`, or a list/tuple whose first
      element is the shape of the input. Only that first shape is used.

  Raises:
    ValueError: If `input_shape` has an unsupported type.
    ValueError: If the last input dimension is unknown.
    ValueError: If no explicit head size was configured and the last input
      dimension is not a multiple of the number of heads.
  """
  if isinstance(input_shape, tf.TensorShape):
    input_tensor_shape = input_shape
  elif isinstance(input_shape, (list, tuple)):
    input_tensor_shape = tf.TensorShape(input_shape[0])
  else:
    raise ValueError(
        "The type of input shape argument is not supported, got: %s" %
        type(input_shape))
  einsum_equation = "abc,cd->abd"
  if len(input_tensor_shape.as_list()) > 3:
    # The ellipsis lets the dense layers accept extra leading dimensions.
    einsum_equation = "...bc,cd->...bd"
  hidden_size = input_tensor_shape[-1]
  if hidden_size is None:
    # Without this check the failure surfaces as an opaque TypeError from
    # `None % num_heads`, or later inside the output dense layer.
    raise ValueError(
        "The last dimension of the input shape must be defined, got: %s" %
        input_tensor_shape)
  if self._head_size is None:
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)
  else:
    self._attention_head_size = self._head_size

  # On entry `self._attention_dropout` and `self._output_dropout` hold dropout
  # *rates*. Below they are rebound to Dropout layers. The names are kept,
  # presumably because `call` uses them -- confirm before renaming. Read the
  # rates once here so no later statement depends on assignment order.
  attention_dropout_rate = self._attention_dropout
  output_dropout_rate = self._output_dropout

  common_kwargs = dict(
      kernel_regularizer=self._kernel_regularizer,
      bias_regularizer=self._bias_regularizer,
      activity_regularizer=self._activity_regularizer,
      kernel_constraint=self._kernel_constraint,
      bias_constraint=self._bias_constraint)
  self._attention_layer = attention.ReuseMultiHeadAttention(
      num_heads=self._num_heads,
      key_dim=self._attention_head_size,
      dropout=attention_dropout_rate,
      use_bias=self._use_bias,
      kernel_initializer=self._attention_initializer,
      bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
      attention_axes=self._attention_axes,
      reuse_attention=self._reuse_attention,
      use_relative_pe=self._use_relative_pe,
      pe_max_seq_length=self._pe_max_seq_length,
      name="self_attention",
      **common_kwargs)
  # Dropout on the attention block's output uses the *output* dropout rate.
  self._attention_dropout = tf.keras.layers.Dropout(rate=output_dropout_rate)
  # Use float32 in layernorm for numeric stability.
  # It is probably safe in mixed_float16, but we haven't validated this yet.
  self._attention_layer_norm = (
      tf.keras.layers.LayerNormalization(
          name="self_attention_layer_norm",
          axis=-1,
          epsilon=self._norm_epsilon,
          dtype=tf.float32))
  self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
      einsum_equation,
      output_shape=(None, self._inner_dim),
      bias_axes="d",
      kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
      bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
      name="intermediate",
      **common_kwargs)
  policy = tf.keras.mixed_precision.global_policy()
  if policy.name == "mixed_bfloat16":
    # bfloat16 causes BERT with the LAMB optimizer to not converge
    # as well, so we use float32.
    # TODO(b/154538392): Investigate this.
    policy = tf.float32
  self._intermediate_activation_layer = tf.keras.layers.Activation(
      self._inner_activation, dtype=policy)
  self._inner_dropout_layer = tf.keras.layers.Dropout(
      rate=self._inner_dropout)
  self._output_dense = tf.keras.layers.experimental.EinsumDense(
      einsum_equation,
      output_shape=(None, hidden_size),
      bias_axes="d",
      name="output",
      kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
      bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
      **common_kwargs)
  self._output_dropout = tf.keras.layers.Dropout(rate=output_dropout_rate)
  # Use float32 in layernorm for numeric stability.
  self._output_layer_norm = tf.keras.layers.LayerNormalization(
      name="output_layer_norm",
      axis=-1,
      epsilon=self._norm_epsilon,
      dtype=tf.float32)
  super(ReuseTransformer, self).build(input_shape)