def test_force_recomputation(self):
  """Tests that an error is thrown when there is no recompute context."""
  dropout = recomputing_dropout.RecomputingDropout(
      0.4, force_recomputation=True)
  with self.assertRaises(ValueError) as assert_raises_context:
    dropout(np.random.normal(size=(2, 8)), training=True)
  self.assertContainsExactSubsequence(
      str(assert_raises_context.exception), 'RecomputeContext is required')
def __init__(self,
             inner_layer=None,
             normalization_layer=None,
             dropout_probability=0.0,
             use_pre_activation_order=False,
             inner_intermediate_size=None,
             inner_activation='relu',
             inner_kernel_initializer=None,
             name='residual_block',
             **kwargs):
  """Init.

  Args:
    inner_layer: Keras layer to apply as the inner layer in the residual
      block. The output of the layer must have the same shape as the input.
      By default, a 2-layer fully-connected network (via `DenseLayers`) is
      created based on the `inner_...` arguments below.
    normalization_layer: Normalization layer to apply. If `inner_layer`
      expects multiple inputs/outputs, then this should be a sequence of
      layers, one for each input. By default this is initialized to a single
      `tf.keras.layers.LayerNormalization` layer, so it must be given when
      expecting multiple `inner_layer` inputs.
    dropout_probability: The probability of dropping out a value when
      applying dropout for the block.
    use_pre_activation_order: If True, use "pre-activation" order (see class
      docstring for details).
    inner_intermediate_size: Size of intermediate fully-connected layer.
      Defaults to the input layer size. Ignored if `inner_layer` is not None.
    inner_activation: Activation function for the intermediate layer. Ignored
      if `inner_layer` is not None.
    inner_kernel_initializer: Initializer to use for fully-connected kernel
      weights. Bias weights are always initialized to 0. Ignored if
      `inner_layer` is not None.
    name: Name of the layer.
    **kwargs: Forwarded to super.
  """
  super(ResidualBlock, self).__init__(name=name, **kwargs)

  if normalization_layer is None:
    normalization_layer = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='layer_norm')
  if isinstance(normalization_layer, Sequence):
    normalization_layers = normalization_layer
  else:
    normalization_layers = [normalization_layer]

  # Inner layer may be created later. Assign `normalization_layers` attribute
  # first, so that the variable order remains the same regardless.
  self.normalization_layers = normalization_layers
  self.inner_layer = inner_layer
  self.dropout_probability = dropout_probability
  self.use_pre_activation_order = use_pre_activation_order
  self.inner_intermediate_size = inner_intermediate_size
  self.inner_activation = inner_activation
  self.inner_kernel_initializer = inner_kernel_initializer

  self.dropout_layers = [
      recomputing_dropout.RecomputingDropout(rate=dropout_probability)
      for _ in self.normalization_layers
  ]
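# Illustrative usage sketch: constructs the block with the default 2-layer
# fully-connected inner layer and a single LayerNormalization. The
# single-tensor call signature and the `training` flag are assumptions based
# on the docstring above; per that docstring, the output has the same shape
# as the input.
def _example_residual_block_usage():
  block = ResidualBlock(
      dropout_probability=0.1,
      inner_intermediate_size=32,
      inner_activation='relu')
  hidden = tf.ones([2, 4, 8], tf.float32)
  # Dropout (via RecomputingDropout) is only active when training=True.
  return block(hidden, training=True)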
def test_recompute_grad(self):
  """Tests that the gradient is computed correctly with recompute_grad."""
  dense = tf.keras.layers.Dense(10, input_shape=(8,))
  dropout = recomputing_dropout.RecomputingDropout(
      0.4, force_recomputation=True)

  @recompute_grad.recompute_grad
  def recompute_dense_dropout(x):
    return dropout(dense(x), training=True)

  # Define the model using dropout.
  def f(x):
    with tf.GradientTape() as tape:
      h1 = recompute_dense_dropout(x)
      h2 = recompute_dense_dropout(x)
      y = tf.math.reduce_sum(h1 + h2)
    return (tf.cast(tf.math.not_equal(h1, 0), tf.float32),
            tf.cast(tf.math.not_equal(h2, 0), tf.float32),
            tape.gradient(y, dense.trainable_variables))

  x = tf.convert_to_tensor(np.random.normal(size=(4, 8)), tf.float32)
  mask1, mask2, gradients = f(x)

  self.evaluate(tf.compat.v1.initializers.global_variables())
  mask1, mask2, gradients = self.evaluate([mask1, mask2, gradients])

  # Make sure entries were masked and there is randomness.
  self.assertGreaterEqual(np.sum(mask1 == 0), 2)
  self.assertGreaterEqual(np.sum(mask2 == 0), 2)
  self.assertNotAllEqual(mask1, mask2)

  # Use the masks to compute exact gradients.
  def g(x):
    with tf.GradientTape() as tape:
      # Rescale by 1 / keep_prob (keep_prob = 1 - 0.4 = 0.6) to match the
      # inverted dropout scaling applied by the layer.
      h1 = (dense(x) * mask1) / 0.6
      h2 = (dense(x) * mask2) / 0.6
      y = tf.math.reduce_sum(h1 + h2)
    return tape.gradient(y, dense.trainable_variables)

  expected_gradients = self.evaluate(g(x))
  self.assertAllClose(gradients, expected_gradients)
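# Side note on the 1 / 0.6 rescaling above (illustrative helper, not used by
# the test): Keras dropout is inverted dropout, so surviving activations are
# scaled by 1 / (1 - rate). With rate=0.4 the keep probability is 0.6, which
# is why the reference computation divides the masked activations by 0.6.
def _inverted_dropout_reference(x, keep_mask, rate=0.4):
  # keep_mask is 1.0 where a value was kept and 0.0 where it was dropped.
  return x * keep_mask / (1.0 - rate)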
def __init__(self,
             hidden_size,
             num_heads,
             att_dropout_prob=0.0,
             share_kv_projections=True,
             initializer=None,
             name='fused_side_attention',
             **kwargs):
  """Init.

  Args:
    hidden_size: Size of the main input hidden dimension. This will also be
      the size of the main output and intermediate queries/keys/values.
    num_heads: Number of attention heads.
    att_dropout_prob: Dropout probability for attention probabilities. Must
      be between 0.0 and 1.0. The default of 0.0 skips dropout.
    share_kv_projections: If True, key and value projections will be shared
      between main-to-main and main-to-side components. This results in 1
      key projection per layer instead of 2 (and similarly for value
      projections).
    initializer: Initializer to use for non-bias variables other than the
      relative embedding table and persistent memory vectors. Bias variables
      will be initialized to 0.
    name: Name of the layer.
    **kwargs: Forwarded to super.
  """
  super(FusedSideAttention, self).__init__(name=name, **kwargs)

  self.hidden_size = hidden_size
  self.num_heads = num_heads
  self.att_dropout_prob = att_dropout_prob
  self.share_kv_projections = share_kv_projections
  self.initializer = initializer

  self._validate_init_parameters()

  def make_att_head_projection(name):
    return ProjectAttentionHeads(
        num_heads=num_heads,
        size_per_head=hidden_size // num_heads,
        use_bias=True,
        initializer=initializer,
        name=name)

  # TODO(urikz): Test if combining projections into one is more efficient
  self.main_query_projection = make_att_head_projection(
      'main_query_projection')
  self.main_key_projection = make_att_head_projection('main_key_projection')
  self.main_value_projection = make_att_head_projection(
      'main_value_projection')

  if self.share_kv_projections:
    self.side_key_projection = self.main_key_projection
    self.side_value_projection = self.main_value_projection
  else:
    self.side_key_projection = make_att_head_projection(
        'side_key_projection')
    self.side_value_projection = make_att_head_projection(
        'side_value_projection')

  if self.att_dropout_prob != 0.0:
    self.att_dropout = recomputing_dropout.RecomputingDropout(
        rate=self.att_dropout_prob)

  self.output_projection = _make_output_projection(
      output_size=self.hidden_size,
      name='output_projection',
      kernel_initializer=initializer)
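# Illustrative construction sketch. Only the constructor is shown above, so no
# call signature is assumed here; this just demonstrates that `hidden_size`
# must be divisible by `num_heads` (each head receives
# hidden_size // num_heads dimensions).
def _example_fused_side_attention_construction():
  return FusedSideAttention(
      hidden_size=768,
      num_heads=12,
      att_dropout_prob=0.1,
      share_kv_projections=True)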
def __init__(self,
             hidden_size,
             num_heads,
             att_dropout_prob=0.0,
             enable_default_side_input=False,
             initializer=None,
             top_k_attention=None,
             pos_embed_mode=None,
             pos_embed_size=None,
             use_one_hot_embeddings=None,
             name='fused_side_attention',
             **kwargs):
  """Init.

  Args:
    hidden_size: Size of the main input hidden dimension. This will also be
      the size of the main output and intermediate queries/keys/values.
    num_heads: Number of attention heads. Must be greater than or equal to
      0, where 0 heads means that the cross-attention layer will have a
      single attention head WITHOUT projection matrices.
    att_dropout_prob: Dropout probability for attention probabilities. Must
      be between 0.0 and 1.0. The default of 0.0 skips dropout.
    enable_default_side_input: Add a default side input, which acts like a
      no-op attention, effectively allowing attention weights to sum up to
      something less than 1.
    initializer: Initializer to use for non-bias variables other than the
      relative embedding table and persistent memory vectors. Bias variables
      will be initialized to 0.
    top_k_attention: If set, restrict attention to the top K items only.
    pos_embed_mode: Whether and how to add positional information.
    pos_embed_size: Max position.
    use_one_hot_embeddings: Whether to use one hot embeddings.
    name: Name of the layer.
    **kwargs: Forwarded to super.
  """
  super(SideAttention, self).__init__(name=name, **kwargs)

  self.hidden_size = hidden_size
  self.num_heads = num_heads
  self.att_dropout_prob = att_dropout_prob
  self.initializer = initializer
  self.enable_default_side_input = enable_default_side_input
  self.top_k_attention = top_k_attention
  self.pos_embed_mode = pos_embed_mode

  self._validate_init_parameters()

  def make_att_head_projection(name):
    if num_heads > 0:
      return ProjectAttentionHeads(
          num_heads=num_heads,
          size_per_head=hidden_size // num_heads,
          use_bias=True,
          initializer=initializer,
          name=name)
    else:
      return None

  self.query_projection = make_att_head_projection('query_projection')
  self.key_projection = make_att_head_projection('key_projection')
  self.value_projection = make_att_head_projection('value_projection')

  if self.num_heads > 0:
    self.output_projection = _make_output_projection(
        output_size=self.hidden_size,
        name='output_projection',
        kernel_initializer=initializer)
  else:
    self.output_projection = tf.keras.layers.Layer()

  if self.att_dropout_prob != 0.0:
    self.att_dropout = recomputing_dropout.RecomputingDropout(
        rate=self.att_dropout_prob)

  if self.pos_embed_mode in [
      'absolute', 'absolute_add_ln', 'simple_relative', 'query_dot_relative'
  ]:
    if pos_embed_size is None:
      raise ValueError('`pos_embed_size` must not be None when '
                       '`pos_embed_mode` is not None')
    if use_one_hot_embeddings is None:
      raise ValueError('`use_one_hot_embeddings` must not be None when '
                       '`pos_embed_mode` is not None')
    self.pos_embed_size = pos_embed_size
    self.block_position_embedding = embedding.EmbeddingLookup(
        vocab_size=(pos_embed_size if self.pos_embed_mode == 'absolute' else
                    2 * pos_embed_size + 1),
        embedding_size=(max(self.num_heads, 1)
                        if self.pos_embed_mode == 'simple_relative' else
                        self.hidden_size),
        initializer_range=0.02,
        use_one_hot_lookup=use_one_hot_embeddings,
        name='block_position_emb_lookup')
    if self.pos_embed_mode == 'absolute_add_ln':
      self.block_position_embedding_norm = tf.keras.layers.LayerNormalization(
          axis=-1, epsilon=1e-12, name='block_position_emb_layer_norm')
  elif self.pos_embed_mode is None:
    self.block_position_embedding = None
  else:
    raise ValueError('Unknown position embeddings mode: ' +
                     self.pos_embed_mode)
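# Illustrative construction sketch. As the checks above require, a relative
# `pos_embed_mode` must be accompanied by `pos_embed_size` and
# `use_one_hot_embeddings`; the values below are placeholders.
def _example_side_attention_construction():
  return SideAttention(
      hidden_size=768,
      num_heads=12,
      att_dropout_prob=0.1,
      pos_embed_mode='simple_relative',
      pos_embed_size=128,
      use_one_hot_embeddings=False)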
def test_nested_recompute_grad(self):
  """Tests nested usage of recompute_grad."""
  dense = tf.keras.layers.Dense(
      5, input_shape=(8,), bias_initializer='glorot_normal')
  dropout = recomputing_dropout.RecomputingDropout(
      0.4, force_recomputation=True)

  @recompute_grad.recompute_grad
  def recompute_dense_dropout_tower(x):
    return dropout(dense(x), training=True)

  def make_head():
    inputs = tf.keras.Input(shape=(5,))
    x = tf.keras.layers.Dense(
        3, activation='tanh', name='dense', bias_initializer='glorot_normal')(
            inputs)
    x = recomputing_dropout.RecomputingDropout(0.45)(x)
    outputs = {
        'head_mask': tf.cast(tf.math.not_equal(x, 0), tf.float32),
        'y': tf.reduce_sum(x),
    }
    return tf.keras.Model(inputs, outputs, name='head')

  head = make_head()

  # Nest recompute_grad inside another recompute_grad function.
  @recompute_grad.recompute_grad
  def recompute_model(x):
    y1 = recompute_dense_dropout_tower(x)
    y2 = recompute_dense_dropout_tower(x)
    outputs = head(y1 + y2, training=True)
    outputs.update({
        'tower1_mask': tf.cast(tf.math.not_equal(y1, 0), tf.float32),
        'tower2_mask': tf.cast(tf.math.not_equal(y2, 0), tf.float32),
    })
    return outputs

  def f(x):
    with tf.GradientTape() as tape:
      outputs = recompute_model(x)
    outputs['gradients'] = tape.gradient(
        outputs.pop('y'),
        dense.trainable_variables + head.trainable_variables)
    return outputs

  x = tf.convert_to_tensor(np.random.normal(size=(4, 8)), tf.float32)
  outputs = f(x)

  self.evaluate(tf.compat.v1.initializers.global_variables())
  outputs = self.evaluate(outputs)

  # Verify gradients are correct.
  def g(x):
    with tf.GradientTape() as tape:
      y1 = dense(x) * outputs['tower1_mask'] / 0.6
      y2 = dense(x) * outputs['tower2_mask'] / 0.6
      y = tf.reduce_sum(
          head.get_layer('dense')(y1 + y2) * outputs['head_mask'] / 0.55)
    return tape.gradient(
        y, dense.trainable_variables + head.trainable_variables)

  # Increase tolerance from default of 1e-6 to reduce flakiness.
  self.assertAllClose(
      outputs['gradients'], self.evaluate(g(x)), rtol=2e-5, atol=2e-5)