def call(self, inputs: Sequence[tf.Tensor], recompute_grad=False, **kwargs):
  """Runs the stack of dense+dropout layers over `inputs`.

  Args:
    inputs: Input activations fed to the first layer.
    recompute_grad: If True, wraps the forward pass with
      `recompute_grad_lib.recompute_grad` so activations are recomputed
      during the backward pass instead of stored.
    **kwargs: Unused; accepted for Keras `call` compatibility.

  Returns:
    The activations produced by the final layer.
  """

  def forward(x):
    """Forward pass; derives dropout seeds deterministically per context."""
    # Inside a recompute context, reuse the context's seed so the
    # recomputed forward pass draws the exact same dropout masks.
    context = recompute_grad_lib.get_recompute_context()
    if context is not None:
      base_seed = tf.stack((context.seed, 0))
    else:
      rng = tf.random.experimental.get_global_generator()
      base_seed = tf.stack(
          (rng.uniform_full_int([], tf.int32, name='seed'), 0))
    # One stateless (seed-pair) row per layer, derived from `base_seed`.
    dropout_seeds = tf.random.stateless_uniform(
        [len(self._units), 2],
        base_seed,
        minval=-2**31,
        maxval=2**31 - 1,
        dtype=tf.int32,
        name='dropout_seeds')
    for layer_index, (kernel, bias) in enumerate(
        zip(self._kernels, self._biases)):
      x = tf.nn.tanh(tf.matmul(x, kernel) + bias)
      x = stateless_dropout_lib.stateless_dropout(
          x, self._rate, dropout_seeds[layer_index])
    return x

  if recompute_grad:
    forward = recompute_grad_lib.recompute_grad(forward)
  return forward(inputs)
def test_seed(self):
  """Tests that a seed can be provided to recompute_grad."""
  seed = tf.constant(2020, dtype=tf.int32)

  def _make_model():
    model_input = tf.keras.Input((4,))
    model_output = tf.keras.layers.Dense(10)(model_input)
    return tf.keras.Model(model_input, model_output)

  model = _make_model()
  if not tf.executing_eagerly():
    self.evaluate(tf.compat.v1.initializers.global_variables())

  # `f` applies stateless dropout to the model output. Under a recompute
  # context it ignores its `seed` argument and uses the context's seed,
  # so the wrapped version below sees the seed passed to recompute_grad.
  def f(x, seed=np.array(1, dtype=np.int32)):
    context = recompute_grad_lib.get_recompute_context()
    if context is not None:
      seed = context.seed
    return stateless_dropout_lib.stateless_dropout(
        model(x), rate=0.5, seed=tf.stack([seed, 0]))

  f_recompute = recompute_grad_lib.recompute_grad(f, seed=seed)

  # Gradients of the plain and the recompute-wrapped function must agree.
  x = tf.ones((2, 4))
  gradients = self.evaluate(_compute_gradients(lambda x: f(x, seed), x))
  recomputed_gradients = self.evaluate(_compute_gradients(f_recompute, x))
  for gradient, recomputed_gradient in zip(gradients, recomputed_gradients):
    self.assertAllClose(gradient, recomputed_gradient)
def call(self,
         long_input: tf.Tensor,
         global_input: tf.Tensor,
         l2l_att_mask: Optional[tf.Tensor] = None,
         g2g_att_mask: Optional[tf.Tensor] = None,
         l2g_att_mask: Optional[tf.Tensor] = None,
         g2l_att_mask: Optional[tf.Tensor] = None,
         l2l_relative_att_ids: Optional[tf.Tensor] = None,
         g2g_relative_att_ids: Optional[tf.Tensor] = None,
         l2g_relative_att_ids: Optional[tf.Tensor] = None,
         g2l_relative_att_ids: Optional[tf.Tensor] = None,
         att_implementation: Text = 'auto',
         training=None) -> List[tf.Tensor]:
  """Calls the layer.

  We use abbreviations like "l2g" to mean "long-to-global".

  Args:
    long_input: <float32>[batch_size, long_seq_len, long_hidden_size].
    global_input: <float32>[batch_size, global_seq_len, global_hidden_size].
    l2l_att_mask: <int32>[batch_size, long_seq_len, 2*local_radius + 1]
      long-to-long attention mask for local attention. Should have only 0 and
      1 values, with 0 for entries that should be masked and 1 otherwise.
      Leave as None to allow all long elements to attend to all other long
      elements within the local radius.
    g2g_att_mask: <int32>[batch_size, global_seq_len, global_seq_len]
      global-to-global attention mask. Should have only 0 and 1 values, with
      0 for entries that should be masked and 1 otherwise. Leave as None to
      allow all global elements to attend to all other global elements within
      each example.
    l2g_att_mask: <int32>[batch_size, long_seq_len, global_seq_len]
      long-to-global attention mask. Should have only 0 and 1 values, with 0
      for entries that should be masked and 1 otherwise. Leave as None to
      allow all long elements to attend to all global elements within each
      example.
    g2l_att_mask: <int32>[batch_size, global_seq_len, long_seq_len]
      global-to-long attention mask. Should have only 0 and 1 values, with 0
      for entries that should be masked and 1 otherwise. Leave as None to
      allow all global elements to attend to all long elements within each
      example.
    l2l_relative_att_ids: <int32>[batch_size, long_seq_len, 2*local_radius+1]
      long-to-long relative local self-attention ids. Leave as None to skip
      the relative portion of l2l attention.
    g2g_relative_att_ids: <int32>[batch_size, global_seq_len, global_seq_len]
      global-to-global relative attention ids. Leave as None to skip the
      relative portion of g2g attention.
    l2g_relative_att_ids: <int32>[batch_size, long_seq_len, global_seq_len]
      long-to-global relative attention ids. Leave as None to skip the
      relative portion of l2g attention.
    g2l_relative_att_ids: <int32>[batch_size, global_seq_len, long_seq_len]
      global-to-long relative attention ids. Leave as None to skip the
      relative portion of g2l attention.
    att_implementation: String representing which internal attention
      implementation to use. Valid values include 'auto' (the default),
      'sparse', and 'full'. 'sparse' is preferred for sequences longer than
      about 1k tokens, but 'full' may be faster for sequences shorter than
      this. 'auto' attempts to automatically decide when to use full
      attention. See `QkvRelativeLocalAttention` for more details.
    training: For Keras, optional boolean scalar tensor or Python boolean
      indicating whether the call is meant for training or inference.

  Returns:
    A list of Tensors, [long_output, global_output]:
      long_output: <float32>[batch_size, long_seq_len, long_hidden_size]
      global_output: <float32>[batch_size, global_seq_len, global_hidden_size]
  """
  long_output = long_input
  global_output = global_input

  def make_layer_fn(index: int):
    """Makes a function that runs the entire `index` layer."""

    def layer_fn(long_input, global_input):
      """A function for an entire layer."""
      long_output = long_input
      global_output = global_input
      # Fused long/global attention for layer `index`; all masks and
      # relative-id tensors are closed over from the outer call.
      long_output, global_output = self.fused_att_layers[index](
          [long_output, global_output],
          l2l_att_mask=l2l_att_mask,
          g2g_att_mask=g2g_att_mask,
          l2g_att_mask=l2g_att_mask,
          g2l_att_mask=g2l_att_mask,
          l2l_relative_att_ids=l2l_relative_att_ids,
          g2g_relative_att_ids=g2g_relative_att_ids,
          l2g_relative_att_ids=l2g_relative_att_ids,
          g2l_relative_att_ids=g2l_relative_att_ids,
          att_implementation=att_implementation,
          training=training)
      # Long and global feed-forward
      long_output = self.long_feed_forward_layers[index](
          long_output, training=training)
      global_output = self.global_feed_forward_layers[index](
          global_output, training=training)
      return (long_output, global_output)

    return layer_fn

  # If `grad_checkpointing_period` is 0 or greater than or equal to the
  # number of layers, no checkpointing will be used.
  stride = (
      self.num_hidden_layers if self.grad_checkpointing_period <= 0 else min(
          self.grad_checkpointing_period, self.num_hidden_layers))
  # Split layers into chains of size `stride`. Put remainder at the beginning:
  # the first `split` endpoint is `stride - (-num_layers % stride)`, i.e. the
  # first chain is shorter when `num_hidden_layers` is not a multiple of
  # `stride`; all subsequent chains have exactly `stride` layers.
  for split in range(stride - (-self.num_hidden_layers % stride),
                     self.num_hidden_layers + 1, stride):
    # Chain layers together with max length `stride`. The reduce partial
    # takes a (long, global) tuple as its initial value and threads it
    # through each layer_fn in the chain.
    layer_fn = functools.partial(
        functools.reduce, lambda outputs, f: f(*outputs),
        list(map(make_layer_fn, range(max(0, split - stride), split))))
    # Destructure arguments for compatibility with `recompute_grad`, which
    # calls the function with positional tensors rather than a tuple.
    layer_fn = functools.partial(lambda f, *args: f(args), layer_fn)
    # Skip the last block. Store activations for gradient computation.
    if split < self.num_hidden_layers:
      layer_fn = recompute_grad_lib.recompute_grad(layer_fn)
    long_output, global_output = layer_fn(long_output, global_output)

  return [long_output, global_output]