    def call(self,
             inputs: Sequence[tf.Tensor],
             recompute_grad=False,
             **kwargs):
        def f(x):
            # On the normal forward pass (no recompute context), draw a fresh
            # seed from the global generator; during the recomputation pass,
            # reuse the seed stored on the recompute context so the dropout
            # masks match the forward pass exactly.
            if recompute_grad_lib.get_recompute_context() is None:
                generator = tf.random.experimental.get_global_generator()
                recompute_grad_seed = tf.stack(
                    (generator.uniform_full_int([], tf.int32, name='seed'), 0))
            else:
                recompute_grad_seed = tf.stack(
                    (recompute_grad_lib.get_recompute_context().seed, 0))
            # Derive one deterministic dropout seed per layer from the shared
            # recompute seed.
            seeds = tf.random.stateless_uniform([len(self._units), 2],
                                                recompute_grad_seed,
                                                minval=-2**31,
                                                maxval=2**31 - 1,
                                                dtype=tf.int32,
                                                name='dropout_seeds')
            for i, (kernel, bias) in enumerate(zip(self._kernels,
                                                   self._biases)):
                x = tf.nn.tanh(tf.matmul(x, kernel) + bias)
                x = stateless_dropout_lib.stateless_dropout(
                    x, self._rate, seeds[i])
            return x

        if recompute_grad:
            # Trade memory for compute: drop intermediate activations and
            # recompute `f` during the backward pass.
            f = recompute_grad_lib.recompute_grad(f)
        return f(inputs)
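
A minimal usage sketch of the pattern above. The owning layer class and its constructor are not shown here; `DenseDropoutStack`, its `units` and `rate` arguments, and the 2-D input shape are assumptions for illustration only. With `recompute_grad=True`, intermediate activations inside `f` are discarded after the forward pass and recomputed during backpropagation, while the seed logic keeps the dropout masks identical across the two passes.

# Hypothetical usage; `DenseDropoutStack(units=..., rate=...)` is an assumed
# constructor, only the `call` method is shown above.
layer = DenseDropoutStack(units=[128, 128], rate=0.1)
x = tf.ones((8, 64))  # `f` treats its input as a single 2-D tensor.
with tf.GradientTape() as tape:
    y = layer(x, recompute_grad=True)
    loss = tf.reduce_sum(y)
# Gradients flow through the recomputed activations.
grads = tape.gradient(loss, layer.trainable_variables)
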
    def test_seed(self):
        """Tests that a seed can be provided to recompute_grad."""
        seed = tf.constant(2020, dtype=tf.int32)

        def _make_model():
            inputs = tf.keras.Input((4, ))
            outputs = tf.keras.layers.Dense(10)(inputs)
            return tf.keras.Model(inputs, outputs)

        model = _make_model()
        if not tf.executing_eagerly():
            self.evaluate(tf.compat.v1.initializers.global_variables())
        # Set up functions to take gradients with respect to variables.
        def f(x, seed=np.array(1, dtype=np.int32)):
            if recompute_grad_lib.get_recompute_context() is not None:
                seed = recompute_grad_lib.get_recompute_context().seed
            return stateless_dropout_lib.stateless_dropout(model(x),
                                                           rate=0.5,
                                                           seed=tf.stack(
                                                               [seed, 0]))

        f_recompute = recompute_grad_lib.recompute_grad(f, seed=seed)
        # Compute gradients and compare them.
        x = tf.ones((2, 4))
        gradients = self.evaluate(_compute_gradients(lambda x: f(x, seed), x))
        recomputed_gradients = self.evaluate(_compute_gradients(
            f_recompute, x))
        for gradient, recomputed_gradient in zip(gradients,
                                                 recomputed_gradients):
            self.assertAllClose(gradient, recomputed_gradient)
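
The `_compute_gradients` helper referenced above is not part of this excerpt. A plausible minimal sketch, under the assumption that it returns the gradients of a scalar reduction of `f(x)` with respect to the trainable variables the tape watches:

def _compute_gradients(f, x):
    # Assumed helper (its real body is not shown above): gradients of a
    # scalar reduction of f(x) with respect to the watched trainable variables.
    with tf.GradientTape() as tape:
        y = tf.reduce_sum(f(x))
    return tape.gradient(y, tape.watched_variables())
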
Example #3
    def call(self,
             long_input: tf.Tensor,
             global_input: tf.Tensor,
             l2l_att_mask: Optional[tf.Tensor] = None,
             g2g_att_mask: Optional[tf.Tensor] = None,
             l2g_att_mask: Optional[tf.Tensor] = None,
             g2l_att_mask: Optional[tf.Tensor] = None,
             l2l_relative_att_ids: Optional[tf.Tensor] = None,
             g2g_relative_att_ids: Optional[tf.Tensor] = None,
             l2g_relative_att_ids: Optional[tf.Tensor] = None,
             g2l_relative_att_ids: Optional[tf.Tensor] = None,
             att_implementation: Text = 'auto',
             training=None) -> List[tf.Tensor]:
        """Calls the layer.

    We use abbreviations like "l2g" to mean "long-to-global".

    Args:
      long_input: <float32>[batch_size, long_seq_len, long_hidden_size].
      global_input: <float32>[batch_size, global_seq_len, global_hidden_size].
      l2l_att_mask: <int32>[batch_size, long_seq_len, 2*local_radius + 1]
        long-to-long attention mask for local attention. Should have only 0 and
        1 values, with 0 for entries that should be masked and 1 otherwise.
        Leave as None to allow all long elements to attend to all other long
        elements within the local radius.
      g2g_att_mask: <int32>[batch_size, global_seq_len, global_seq_len]
        global-to-global attention mask. Should have only 0 and 1 values, with 0
        for entries that should be masked and 1 otherwise. Leave as None to
        allow all global elements to attend to all other global elements within
        each example.
      l2g_att_mask: <int32>[batch_size, long_seq_len, global_seq_len]
        long-to-global attention mask. Should have only 0 and 1 values, with 0
        for entries that should be masked and 1 otherwise. Leave as None to
        allow all long elements to attend to all global elements within each
        example.
      g2l_att_mask: <int32>[batch_size, global_seq_len, long_seq_len]
        global-to-long attention mask. Should have only 0 and 1 values, with 0
        for entries that should be masked and 1 otherwise. Leave as None to
        allow all global elements to attend to all long elements within each
        example.
      l2l_relative_att_ids: <int32>[batch_size, long_seq_len, 2*local_radius+1]
        long-to-long relative local self-attention ids. Leave as None to skip
        the relative portion of l2l attention.
      g2g_relative_att_ids: <int32>[batch_size, global_seq_len, global_seq_len]
        global-to-global relative attention ids. Leave as None to skip the
        relative portion of g2g attention.
      l2g_relative_att_ids: <int32>[batch_size, long_seq_len, global_seq_len]
        long-to-global relative attention ids. Leave as None to skip the
        relative portion of l2g attention.
      g2l_relative_att_ids: <int32>[batch_size, global_seq_len, long_seq_len]
        global-to-long relative attention ids. Leave as None to skip the
        relative portion of g2l attention.
      att_implementation: String representing which internal attention
        implementation to use. Valid values include 'auto' (the default),
        'sparse', and 'full'. 'sparse' is preferred for sequences longer than
        about 1k tokens, but 'full' may be faster for sequences shorter than
        this. 'auto' attempts to automatically decide when to use full
        attention. See `QkvRelativeLocalAttention` for more details.
      training: For Keras, optional boolean scalar tensor or Python boolean
        indicating whether the call is meant for training or inference.

    Returns:
      A list of Tensors, [long_output, global_output]:
        long_output: <float32>[batch_size, long_seq_len, long_hidden_size]
        global_output: <float32>[batch_size, global_seq_len, global_hidden_size]
    """
        long_output = long_input
        global_output = global_input

        def make_layer_fn(index: int):
            """Makes a function that runs the entire `index` layer."""
            def layer_fn(long_input, global_input):
                """A function for an entire layer."""
                long_output = long_input
                global_output = global_input

                long_output, global_output = self.fused_att_layers[index](
                    [long_output, global_output],
                    l2l_att_mask=l2l_att_mask,
                    g2g_att_mask=g2g_att_mask,
                    l2g_att_mask=l2g_att_mask,
                    g2l_att_mask=g2l_att_mask,
                    l2l_relative_att_ids=l2l_relative_att_ids,
                    g2g_relative_att_ids=g2g_relative_att_ids,
                    l2g_relative_att_ids=l2g_relative_att_ids,
                    g2l_relative_att_ids=g2l_relative_att_ids,
                    att_implementation=att_implementation,
                    training=training)

                # Long and global feed-forward
                long_output = self.long_feed_forward_layers[index](
                    long_output, training=training)
                global_output = self.global_feed_forward_layers[index](
                    global_output, training=training)

                return (long_output, global_output)

            return layer_fn

        # If `grad_checkpointing_period` is non-positive or at least
        # `num_hidden_layers`, all layers fall into a single chain and no
        # gradient checkpointing is applied.
        stride = (self.num_hidden_layers
                  if self.grad_checkpointing_period <= 0 else min(
                      self.grad_checkpointing_period, self.num_hidden_layers))
        # Split layers into chains of size `stride`. Put remainder at the beginning.
        for split in range(stride - (-self.num_hidden_layers % stride),
                           self.num_hidden_layers + 1, stride):
            # Chain layers together with max length `stride`.
            layer_fn = functools.partial(
                functools.reduce, lambda outputs, f: f(*outputs),
                list(map(make_layer_fn, range(max(0, split - stride), split))))
            # Destructure arguments for compatibility with `recompute_grad`.
            layer_fn = functools.partial(lambda f, *args: f(args), layer_fn)
            # Wrap every chain except the last in `recompute_grad`; the last
            # chain's activations are stored as usual for the backward pass.
            if split < self.num_hidden_layers:
                layer_fn = recompute_grad_lib.recompute_grad(layer_fn)
            long_output, global_output = layer_fn(long_output, global_output)

        return [long_output, global_output]
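
To make the chaining arithmetic above concrete, here is a small standalone illustration; the values 7 and 3 are hypothetical and do not come from the model configuration:

# Standalone illustration of the layer-chaining arithmetic used in `call`.
num_hidden_layers = 7            # hypothetical
grad_checkpointing_period = 3    # hypothetical
stride = (num_hidden_layers if grad_checkpointing_period <= 0
          else min(grad_checkpointing_period, num_hidden_layers))
groups = []
for split in range(stride - (-num_hidden_layers % stride),
                   num_hidden_layers + 1, stride):
    group = list(range(max(0, split - stride), split))
    # Every chain except the last would be wrapped in recompute_grad.
    groups.append((group, split < num_hidden_layers))
# The remainder chain of size 1 comes first; only the last chain keeps its
# activations.
assert groups == [([0], True), ([1, 2, 3], True), ([4, 5, 6], False)]
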