Example #1
  def test_make_relative_att_ids_batch_size_2_tensor(self):
    dummy_batch = tf.ones([2, 5])

    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=3)

    expected = [
        [
            [0, 1, 2, 3, 3],  #
            [4, 0, 1, 2, 3],  #
            [5, 4, 0, 1, 2],  #
            [6, 5, 4, 0, 1],  #
            [6, 6, 5, 4, 0],  #
        ],
        [
            [0, 1, 2, 3, 3],  #
            [4, 0, 1, 2, 3],  #
            [5, 4, 0, 1, 2],  #
            [6, 5, 4, 0, 1],  #
            [6, 6, 5, 4, 0],  #
        ]
    ]
    self.assertAllEqual(
        expected,
        relative_pos_gen.make_relative_att_ids(
            seq_len=5, batch_size=tf.shape(dummy_batch)[0]))
Example #2
  def test_relative_position_generator_init_max_distance_0(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=0)

    self.assertEqual(0, relative_pos_gen.max_distance)
    self.assertEqual(False, relative_pos_gen.ignore_direction)
    self.assertEqual(1, relative_pos_gen.relative_vocab_size)
    self.assertEqual(0, relative_pos_gen.left_pad_value)
    self.assertEqual(0, relative_pos_gen.right_pad_value)
Example #3
  def test_make_relative_att_ids_invalid_arguments(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=3)

    with self.assertRaises(ValueError):
      relative_pos_gen.make_relative_att_ids(0)

    with self.assertRaises(ValueError):
      relative_pos_gen.make_relative_att_ids(seq_len=5, batch_size=0)
Example #4
  def test_relative_position_generator_init_ignore_direction(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(
        max_distance=3, ignore_direction=True)

    self.assertEqual(3, relative_pos_gen.max_distance)
    self.assertEqual(True, relative_pos_gen.ignore_direction)
    self.assertEqual(4, relative_pos_gen.relative_vocab_size)
    self.assertEqual(3, relative_pos_gen.left_pad_value)
    self.assertEqual(3, relative_pos_gen.right_pad_value)
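
The assertions above suggest a simple relationship between `max_distance`,
`ignore_direction`, and the resulting `relative_vocab_size`: undirected ids
span 0..max_distance, while directional ids need max_distance additional
"left" ids on top of that. A minimal sketch of that apparent pattern (the
helper `expected_relative_vocab_size` is hypothetical, not part of
`feature_utils`):

def expected_relative_vocab_size(max_distance: int,
                                 ignore_direction: bool) -> int:
  # Undirected: ids 0..max_distance. Directional: the same ids plus
  # max_distance extra ids for tokens to the left.
  return max_distance + 1 if ignore_direction else 2 * max_distance + 1

assert expected_relative_vocab_size(0, ignore_direction=False) == 1  # Example #2
assert expected_relative_vocab_size(3, ignore_direction=True) == 4   # Example #4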
Example #5
  def test_make_local_relative_att_ids_max_distance_0(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=0)

    expected = [[
        [0, 0, 0, 0, 0],  #
        [0, 0, 0, 0, 0],  #
    ]]
    self.assertAllEqual(
        expected,
        relative_pos_gen.make_local_relative_att_ids(seq_len=2, local_radius=2))
Example #6
  def test_make_relative_att_ids_trimming_case(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=9)

    expected = [[
        [0, 1, 2, 3, 4],  #
        [10, 0, 1, 2, 3],  #
        [11, 10, 0, 1, 2],  #
        [12, 11, 10, 0, 1],  #
        [13, 12, 11, 10, 0],  #
    ]]
    self.assertAllEqual(expected, relative_pos_gen.make_relative_att_ids(5))
Example #7
  def test_make_local_relative_att_ids_no_pad_or_trim_case(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=4)

    expected = [[
        [8, 7, 6, 5, 0, 1, 2, 3, 4],  #
        [8, 7, 6, 5, 0, 1, 2, 3, 4],  #
        [8, 7, 6, 5, 0, 1, 2, 3, 4],  #
    ]]
    self.assertAllEqual(
        expected,
        relative_pos_gen.make_local_relative_att_ids(seq_len=3, local_radius=4))
Example #8
  def test_make_local_relative_att_ids_trimming_case(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=9)

    expected = [[
        [13, 12, 11, 10, 0, 1, 2, 3, 4],  #
        [13, 12, 11, 10, 0, 1, 2, 3, 4],  #
        [13, 12, 11, 10, 0, 1, 2, 3, 4],  #
    ]]
    self.assertAllEqual(
        expected,
        relative_pos_gen.make_local_relative_att_ids(seq_len=3, local_radius=4))
Example #9
  def test_make_relative_att_ids_no_pad_or_trim_case(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=4)

    expected = [[
        [0, 1, 2, 3, 4],  #
        [5, 0, 1, 2, 3],  #
        [6, 5, 0, 1, 2],  #
        [7, 6, 5, 0, 1],  #
        [8, 7, 6, 5, 0],  #
    ]]
    self.assertAllEqual(expected, relative_pos_gen.make_relative_att_ids(5))
Example #10
  def test_make_relative_att_ids_padding_case(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=3)

    expected = [[
        [0, 1, 2, 3, 3, 3],  #
        [4, 0, 1, 2, 3, 3],  #
        [5, 4, 0, 1, 2, 3],  #
        [6, 5, 4, 0, 1, 2],  #
        [6, 6, 5, 4, 0, 1],  #
        [6, 6, 6, 5, 4, 0],  #
    ]]
    self.assertAllEqual(expected, relative_pos_gen.make_relative_att_ids(6))
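
Judging from the `expected` matrix above (max_distance=3), a token j to the
right of token i appears to get id min(j - i, max_distance), while a token to
the left appears to get id max_distance + min(i - j, max_distance). A minimal
NumPy sketch that reproduces the matrix under that assumption
(`build_relative_ids` is a hypothetical helper, not the library
implementation):

import numpy as np

def build_relative_ids(seq_len: int, max_distance: int) -> np.ndarray:
  ids = np.zeros((seq_len, seq_len), dtype=np.int32)
  for i in range(seq_len):
    for j in range(seq_len):
      if j >= i:
        ids[i, j] = min(j - i, max_distance)  # token j is to the right of i
      else:
        ids[i, j] = max_distance + min(i - j, max_distance)  # j is to the left
  return ids

print(build_relative_ids(6, 3))  # matches the `expected` matrix above

The local variants (e.g. Examples #7 and #11) appear to apply the same mapping
to window offsets from -local_radius to local_radius instead of to full
position pairs.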
Example #11
  def test_make_local_relative_att_ids_padding_case(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=3)

    expected = [[
        [6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],  #
        [6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],  #
        [6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],  #
        [6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],  #
    ]]
    self.assertAllEqual(
        expected,
        relative_pos_gen.make_local_relative_att_ids(seq_len=4, local_radius=5))
Example #12
  def test_make_relative_att_ids_padding_case_ignore_direction(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(
        max_distance=3, ignore_direction=True)

    expected = [[
        [0, 1, 2, 3, 3, 3],  #
        [1, 0, 1, 2, 3, 3],  #
        [2, 1, 0, 1, 2, 3],  #
        [3, 2, 1, 0, 1, 2],  #
        [3, 3, 2, 1, 0, 1],  #
        [3, 3, 3, 2, 1, 0],  #
    ]]
    self.assertAllEqual(expected, relative_pos_gen.make_relative_att_ids(6))
Example #13
  def test_make_local_relative_att_ids_batch_size_2(self):
    relative_pos_gen = feature_utils.RelativePositionGenerator(max_distance=3)

    expected = [
        [
            [6, 6, 5, 4, 0, 1, 2, 3, 3],  #
            [6, 6, 5, 4, 0, 1, 2, 3, 3],  #
            [6, 6, 5, 4, 0, 1, 2, 3, 3],  #
        ],
        [
            [6, 6, 5, 4, 0, 1, 2, 3, 3],  #
            [6, 6, 5, 4, 0, 1, 2, 3, 3],  #
            [6, 6, 5, 4, 0, 1, 2, 3, 3],  #
        ],
    ]
    self.assertAllEqual(
        expected,
        relative_pos_gen.make_local_relative_att_ids(
            seq_len=3, local_radius=4, batch_size=2))
Example #14
  def test_relative_position_generator_init_invalid_arguments(self):
    with self.assertRaises(ValueError):
      feature_utils.RelativePositionGenerator(max_distance=-1)
Example #15
  def __init__(self,
               config: EtcConfig,
               is_training: Optional[bool] = None,
               use_one_hot_embeddings=False,
               use_one_hot_relative_embeddings=False,
               name: Text = "etc_document_bert",
               **kwargs):
    """Constructor for `EtcModel`.

    Args:
      config: `EtcConfig` instance.
      is_training: Optional bool. True for training model, False for eval model.
        The None default will defer to the typical Keras `training` argument in
        `call` instead. When `is_training` is specified here, the `training`
        argument from `call` must not be used.
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.nn.embedding_lookup() for the word embeddings.
      use_one_hot_relative_embeddings: (optional) bool. Whether to use one-hot
        lookup or tf.nn.embedding_lookup() for the relative position
        embeddings.
      name: (Optional) name of the layer.
      **kwargs: Forwarded to super.

    Raises:
      ValueError: The config is invalid.
    """
    super(EtcModel, self).__init__(name=name, **kwargs)

    config = copy.deepcopy(config)
    if is_training is not None and not is_training:
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    self.config = config
    self.is_training = is_training
    self.use_one_hot_embeddings = use_one_hot_embeddings
    self.use_one_hot_relative_embeddings = use_one_hot_relative_embeddings

    if config.relative_vocab_size is None:
      if config.relative_pos_max_distance != 0:
        raise ValueError(
            "`relative_pos_max_distance` must be 0 when `relative_vocab_size` "
            "is None.")
    elif config.relative_vocab_size < (feature_utils.RelativePositionGenerator(
        config.relative_pos_max_distance).relative_vocab_size +
                                       _NUM_OTHER_RELATIVE_IDS):
      raise ValueError("`relative_vocab_size` ({}) too small for "
                       "`relative_pos_max_distance` ({})".format(
                           config.relative_vocab_size,
                           config.relative_pos_max_distance))
    if config.embedding_size is None:
      config.embedding_size = config.hidden_size

    self.token_embedding = etc_layers.EmbeddingLookup(
        vocab_size=config.vocab_size,
        embedding_size=config.embedding_size,
        projection_size=config.hidden_size,
        initializer_range=config.initializer_range,
        use_one_hot_lookup=use_one_hot_embeddings,
        name="token_emb_lookup")

    self.token_embedding_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name="long_emb_layer_norm")
    self.token_embedding_dropout = tf.keras.layers.Dropout(
        rate=config.hidden_dropout_prob)

    self.segment_embedding = etc_layers.EmbeddingLookup(
        vocab_size=config.segment_vocab_size,
        embedding_size=config.hidden_size,
        initializer_range=config.initializer_range,
        use_one_hot_lookup=True,
        name="segment_emb_lookup")

    if config.max_absolute_position_embeddings != 0:
      self.position_embedding = etc_layers.EmbeddingLookup(
          vocab_size=config.max_absolute_position_embeddings,
          embedding_size=config.hidden_size,
          initializer_range=config.initializer_range,
          use_one_hot_lookup=use_one_hot_embeddings,
          name="position_emb_lookup_long")
      # We use `max_absolute_position_embeddings` for the maximum global input
      # length even though it's larger than we need. This makes it easier to
      # initialize both long and global position embedding tables with the same
      # values if desired.
      self.global_position_embedding = etc_layers.EmbeddingLookup(
          vocab_size=config.max_absolute_position_embeddings,
          embedding_size=config.hidden_size,
          initializer_range=config.initializer_range,
          use_one_hot_lookup=use_one_hot_embeddings,
          name="position_emb_lookup_global")
      # Call layers to force variable initialization.
      self.position_embedding(tf.ones([1, 1], tf.int32))
      self.global_position_embedding(tf.ones([1, 1], tf.int32))
    else:
      self.position_embedding = None
      self.global_position_embedding = None

    # We use the same embedding table for global tokens to make it easy to place
    # WordPieces in the global memory for finetuning tasks downstream.
    self.global_token_embedding = self.token_embedding
    self.global_token_embedding_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name="global_emb_layer_norm")
    self.global_token_embedding_dropout = tf.keras.layers.Dropout(
        rate=config.hidden_dropout_prob)

    self.global_local_transformer = etc_layers.GlobalLocalTransformerLayers(
        long_hidden_size=config.hidden_size,
        global_hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        local_radius=config.local_radius,
        att_size_per_head=config.att_size_per_head,
        long_intermediate_size=config.intermediate_size,
        global_intermediate_size=config.intermediate_size,
        hidden_act=tensor_utils.get_activation(config.hidden_act),
        hidden_dropout_prob=config.hidden_dropout_prob,
        attention_probs_dropout_prob=config.attention_probs_dropout_prob,
        initializer_range=config.initializer_range,
        relative_vocab_size=config.relative_vocab_size,
        share_feed_forward_params=config.share_feed_forward_params,
        share_kv_projections=config.share_kv_projections,
        share_qkv_projections=config.share_qkv_projections,
        share_att_output_projection=config.share_att_output_projection,
        use_pre_activation_order=config.use_pre_activation_order,
        use_one_hot_lookup=use_one_hot_relative_embeddings,
        grad_checkpointing_period=config.grad_checkpointing_period)
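
A minimal construction sketch for the model above. It assumes `EtcConfig` can
be built directly from keyword arguments and that any fields not listed have
sensible defaults; the values below are illustrative only:

config = EtcConfig(
    vocab_size=30522,
    segment_vocab_size=16,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    local_radius=84,
    relative_pos_max_distance=12,
    relative_vocab_size=32,
    max_absolute_position_embeddings=0)

# With `is_training=False`, the constructor zeroes out both dropout
# probabilities before building the layers, as shown above.
model = EtcModel(config=config, is_training=False)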
Example #16
def make_global_local_transformer_side_inputs(
    long_paragraph_breakpoints: tf.Tensor,
    long_paragraph_ids: tf.Tensor,
    long_sentence_ids: tf.Tensor,
    global_paragraph_breakpoints: tf.Tensor,
    local_radius: int,
    relative_pos_max_distance: int,
    use_hard_g2l_mask: bool = False,
    ignore_hard_g2l_mask: Optional[tf.Tensor] = None,
    use_hard_l2g_mask: bool = False,
    ignore_hard_l2g_mask: Optional[tf.Tensor] = None,
    flat_sequence: bool = False,
    l2g_linked_ids: Optional[tf.Tensor] = None,
    name: Optional[Text] = None
) -> input_utils.GlobalLocalTransformerSideInputs:
    """Makes attention masks and relative ids for l2l, l2g, g2g, g2l for QA tasks.

  When `use_hard_g2l_mask=True` and `use_hard_l2g_mask=False`, the resulting
  attention pattern is similar to Figure 3b of the paper for representing
  a set of (unordered) contexts ("paragraphs" here), except instead of
  defining a new relative position label between a global paragraph token and
  its global sentence tokens, we just place each global paragraph token as
  the first token before subsequent global sentence tokens belonging to it.

  Note: This function assumes that we don't pack multiple examples into a single
  example, which is only done for pre-training.

  See `GlobalLocalTransformerLayers.call()` in `layers/transformer.py` for a
  description of the 8 side inputs.

  Args:
    long_paragraph_breakpoints: <int32>[batch_size, long_seq_len] Tensor of
      `0`s and `1`s indicating paragraph boundaries in the long input.
    long_paragraph_ids: <int32>[batch_size, long_seq_len] Tensor of ids
      indicating the paragraph each token belongs to.
    long_sentence_ids: <int32>[batch_size, long_seq_len] Tensor of ids
      indicating which sentence each token belongs to.
    global_paragraph_breakpoints: <int32>[batch_size, global_seq_len] Tensor of
      `0`s and `1`s indicating paragraph boundaries in the global input.
    local_radius: How many tokens to the left/right for input tokens to locally
      self-attend to. For example, a value of 1 would allow each token to only
      attend to 1 token to the left and 1 token to the right of it.
    relative_pos_max_distance: Maximum distance to use for relative position
      representations. All larger distances will be clipped to this value. Use 0
      to skip relative position representations entirely.
    use_hard_g2l_mask: If True, global tokens only attend to tokens of the
      corresponding sentences in the long input. If False, global tokens attend
      to all sentences within the corresponding global example.
    ignore_hard_g2l_mask: <int32>[batch_size, global_seq_len] Tensor of `0`s and
      `1`s indicating the indices in the global input which should ignore the
      `use_hard_g2l_mask`. `1` is for ignoring the hard mask and these tokens
      essentially attend to everything (except for padding tokens) in the long
      input. This can be useful to force some tokens (e.g., CLS) to attend to
      everything in the long input even though they don't necessarily map to
      anything in the long input via sentence / paragraph ids etc. This tensor
      is applicable only when `use_hard_g2l_mask` is enabled.
    use_hard_l2g_mask: If True, long tokens only attend to their corresponding
      global tokens. If False, long tokens attend to all the global tokens
      within the corresponding global example.
    ignore_hard_l2g_mask: <int32>[batch_size, long_seq_len] Tensor of `0`s and
      `1`s indicating the indices in the long input which should ignore the
      `use_hard_l2g_mask`. `1` is for ignoring the hard mask and these tokens
      essentially attend to everything (except for padding tokens) in the global
      input. This can be useful to force some tokens (e.g., query tokens) to
      attend to everything in the global input even though they don't
      necessarily map to anything in the global input via sentence / paragraph
      ids etc. This tensor is applicable only when `use_hard_l2g_mask` is
      enabled.
    flat_sequence: If True, the attention masks / relative attention ids will
      be computed assuming the default ETC setting where there is no structure
      (except for the notion of a "sentence").
    l2g_linked_ids: <int32>[batch_size, long_seq_len] Tensor specifying the long
      tokens which should be linked to the global tokens. For example, if the
      input is [[-1, -1, 0, 1, 1, -1]], then the 2nd long token would be linked
      to the 0th global token, and the 3rd and 4th long tokens would be linked
      to the 1st global token.
    name: A name for the operation (optional).

  Returns:
    A `GlobalLocalTransformerSideInputs` with all relevant tensors set.
  """
    with tf.name_scope(name or 'make_global_local_transformer_side_inputs'):

        long_input_mask = tf.minimum(
            tf.cumsum(long_paragraph_breakpoints, axis=-1, reverse=True), 1)
        global_input_mask = tf.minimum(
            tf.cumsum(global_paragraph_breakpoints, axis=-1, reverse=True), 1)

        if flat_sequence:
            # Here we don't use any structure in the input, i.e. it falls back
            # to the default ETC setting where:
            # a) everything in the long input can attend to everything in the
            #    global input and vice-versa.
            # b) everything in the global input attends to everything in the
            #    global input.
            # c) everything in the long input can attend to everything in the
            #    long input that is within the local radius.
            #
            # Note that there is a small caveat here: the paragraph / CLS level
            # tokens in the global input would be orphaned (i.e. they wouldn't
            # be linked to anything in the long input), but that should
            # probably be okay as they still attend to everything in the
            # global input.
            #
            # We don't have any packing here, so we need to construct
            # long/global breakpoints to indicate there's only one example.
            # The structure of these breakpoints should be as follows:
            # [0, 0, ..., 1, 0, 0, 0], i.e. there should be a single `1` just
            # before the padding begins and the rest of the tokens should be
            # `0`.
            return (input_utils.
                    make_global_local_transformer_side_inputs_from_example_ids(
                        long_example_ids=long_input_mask,
                        global_example_ids=global_input_mask,
                        sentence_ids=long_sentence_ids,
                        local_radius=local_radius,
                        relative_pos_max_distance=relative_pos_max_distance,
                        use_hard_g2l_mask=use_hard_g2l_mask,
                        use_hard_l2g_mask=use_hard_l2g_mask))

        # Make paragraphs not attend to other paragraphs in the long input.
        long_paragraph_breakpoints = tf.convert_to_tensor(
            long_paragraph_breakpoints)
        long_paragraph_breakpoint_segments = tf.cumsum(
            long_paragraph_breakpoints, axis=-1, reverse=True)

        l2l_att_mask = feature_utils.make_local_segmented_att_mask(
            long_paragraph_breakpoint_segments, local_radius)

        global_paragraph_breakpoints = tf.convert_to_tensor(
            global_paragraph_breakpoints)
        global_paragraph_breakpoint_segments = tf.cumsum(
            global_paragraph_breakpoints, axis=-1, reverse=True)

        # For g2l, g2g and l2g, we can have everything attend everything else.
        # So we can have attention tokens as all `1`s and account for padding via
        # a mask.
        def _make_input_mask_from_breakpoints(
                breakpoint_segments: tf.Tensor) -> tf.Tensor:
            return tf.minimum(tf.cast(1, dtype=breakpoint_segments.dtype),
                              breakpoint_segments)

        long_attention_tokens = _make_input_mask_from_breakpoints(
            long_paragraph_breakpoint_segments)

        # Ignore the padding tokens.
        global_attention_tokens = _make_input_mask_from_breakpoints(
            global_paragraph_breakpoint_segments)

        g2g_att_mask = feature_utils.make_segmented_att_mask(
            global_attention_tokens)
        l2g_att_mask = tf.cast(
            tf.equal(long_attention_tokens[:, :, tf.newaxis],
                     global_attention_tokens[:, tf.newaxis, :]), tf.int32)
        g2l_att_mask = tf.transpose(l2g_att_mask, perm=[0, 2, 1])

        long_seq_len = long_paragraph_breakpoints.shape.as_list()[1]
        assert long_seq_len is not None

        global_seq_len = global_paragraph_breakpoints.shape.as_list()[1]
        assert global_seq_len is not None

        batch_size = tf.shape(long_paragraph_breakpoints)[0]
        assert batch_size is not None

        global_range = tf.range(global_seq_len, dtype=long_sentence_ids.dtype)
        long_ones = tf.ones_like(long_sentence_ids)
        global_ones = tf.ones_like(global_paragraph_breakpoints)

        if use_hard_g2l_mask:
            if ignore_hard_g2l_mask is None:
                ignore_hard_g2l_mask = tf.zeros_like(
                    global_paragraph_breakpoints)
            else:
                ignore_hard_g2l_mask = tf.convert_to_tensor(
                    ignore_hard_g2l_mask)

            # Have each global token attend to just one sentence instead of having
            # it attend to all the sentences within a global example.
            sentence_hard_g2l_att_mask = tf.equal(
                global_range[tf.newaxis, :, tf.newaxis],
                long_sentence_ids[:, tf.newaxis, :])

            # Also have paragraph global tokens attend to the corresponding long
            # paragraphs.
            paragraph_hard_g2l_att_mask = tf.equal(
                global_range[tf.newaxis, :, tf.newaxis],
                long_paragraph_ids[:, tf.newaxis, :])

            ignore_hard_g2l_att_mask = tf.equal(
                ignore_hard_g2l_mask[:, :, tf.newaxis],
                long_ones[:, tf.newaxis, :])

            # It's possible that certain global tokens, although linked to a long
            # sentence, might still be present in `ignore_hard_g2l_mask`. Such tokens
            # should also attend to everything in the long.
            hard_g2l_att_mask = tf.math.logical_or(
                tf.math.logical_or(sentence_hard_g2l_att_mask,
                                   paragraph_hard_g2l_att_mask),
                ignore_hard_g2l_att_mask)

            hard_g2l_att_mask = tf.cast(hard_g2l_att_mask, dtype=tf.int32)
            g2l_att_mask *= hard_g2l_att_mask

        if use_hard_l2g_mask:
            if ignore_hard_l2g_mask is None:
                ignore_hard_l2g_mask = tf.zeros_like(long_sentence_ids)
            else:
                ignore_hard_l2g_mask = tf.convert_to_tensor(
                    ignore_hard_l2g_mask)

            # Have each long token attend to just the corresponding global token
            # instead of having it attend to all the global tokens within a
            # global example.
            sentence_hard_l2g_att_mask = tf.equal(
                long_sentence_ids[:, :, tf.newaxis],
                global_range[tf.newaxis, tf.newaxis, :])

            # Also have paragraph global tokens attend to the corresponding long
            # paragraphs.
            paragraph_hard_l2g_att_mask = tf.equal(
                long_paragraph_ids[:, :, tf.newaxis],
                global_range[tf.newaxis, tf.newaxis, :])

            ignore_hard_l2g_att_mask = tf.equal(
                ignore_hard_l2g_mask[:, :, tf.newaxis],
                global_ones[:, tf.newaxis, :])

            # It's possible that certain long tokens, although linked to global tokens
            # might still be present in `ignore_hard_l2g_mask`. Such tokens
            # should also attend to everything in the global.
            hard_l2g_att_mask = tf.math.logical_or(
                tf.math.logical_or(sentence_hard_l2g_att_mask,
                                   paragraph_hard_l2g_att_mask),
                ignore_hard_l2g_att_mask)

            hard_l2g_att_mask = tf.cast(hard_l2g_att_mask, dtype=tf.int32)
            l2g_att_mask *= hard_l2g_att_mask

        l2l_relative_att_ids = None
        g2g_relative_att_ids = None
        l2g_relative_att_ids = None
        g2l_relative_att_ids = None

        if relative_pos_max_distance > 0:

            relative_pos_generator = feature_utils.RelativePositionGenerator(
                relative_pos_max_distance)

            l2l_relative_att_ids = relative_pos_generator.make_local_relative_att_ids(
                seq_len=long_seq_len,
                local_radius=local_radius,
                batch_size=batch_size)

            sentence_l2g_relative_att_ids = tf.equal(
                long_sentence_ids[:, :, tf.newaxis],
                global_range[tf.newaxis, tf.newaxis, :])

            # Add relative att ids for global paragraph level tokens.
            paragraph_l2g_relative_att_ids = tf.equal(
                global_range[tf.newaxis, tf.newaxis, :],
                long_paragraph_ids[:, :, tf.newaxis])

            if l2g_linked_ids is None:
                l2g_linked_relative_att_ids = tf.zeros_like(
                    paragraph_l2g_relative_att_ids)
            else:
                l2g_linked_ids = tf.convert_to_tensor(l2g_linked_ids)
                l2g_linked_relative_att_ids = tf.equal(
                    global_range[tf.newaxis, tf.newaxis, :],
                    l2g_linked_ids[:, :, tf.newaxis])

            l2g_relative_att_ids = tf.cast(tf.math.logical_or(
                l2g_linked_relative_att_ids,
                tf.math.logical_or(sentence_l2g_relative_att_ids,
                                   paragraph_l2g_relative_att_ids)),
                                           dtype=tf.int32)

            g2l_relative_att_ids = tf.transpose(l2g_relative_att_ids,
                                                perm=[0, 2, 1])

            # For fused attention, l2l and l2g share the same relative vocabulary, as
            # do g2g and g2l, so we add an offset for l2g and g2l so their original
            # 0/1 ids don't collide with l2l and g2g relative position ids.
            l2g_relative_att_ids += relative_pos_generator.relative_vocab_size
            g2l_relative_att_ids += relative_pos_generator.relative_vocab_size

            g2g_relative_att_ids = relative_pos_generator.make_relative_att_ids(
                seq_len=global_seq_len, batch_size=batch_size)

            # We used up 2 ids to account for the collision in fused attention as
            # mentioned above. Hence the +2.
            g2g_max_rel_id = relative_pos_generator.relative_vocab_size + 2
            g2g_relative_att_ids = (
                feature_utils.overwrite_relative_att_ids_outside_segments(
                    rel_att_ids=g2g_relative_att_ids,
                    segment_ids=global_paragraph_breakpoint_segments,
                    overwrite_value=g2g_max_rel_id))

        return input_utils.GlobalLocalTransformerSideInputs(
            l2l_att_mask=l2l_att_mask,
            g2g_att_mask=g2g_att_mask,
            l2g_att_mask=l2g_att_mask,
            g2l_att_mask=g2l_att_mask,
            l2l_relative_att_ids=l2l_relative_att_ids,
            g2g_relative_att_ids=g2g_relative_att_ids,
            l2g_relative_att_ids=l2g_relative_att_ids,
            g2l_relative_att_ids=g2l_relative_att_ids)
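
A minimal call sketch with toy tensors for a single example containing one
paragraph of two sentences plus padding, following the conventions in the
docstring above (padding ids are set to -1 here as an assumption so they match
no global position); it also assumes the function is imported from its
defining module, which this listing does not name:

import tensorflow as tf

# Long input: 6 real tokens (sentences at global indices 0 and 1, paragraph
# token at global index 2) followed by 2 padding tokens.
side_inputs = make_global_local_transformer_side_inputs(
    long_paragraph_breakpoints=tf.constant([[0, 0, 0, 0, 0, 1, 0, 0]]),
    long_paragraph_ids=tf.constant([[2, 2, 2, 2, 2, 2, -1, -1]]),
    long_sentence_ids=tf.constant([[0, 0, 0, 1, 1, 1, -1, -1]]),
    global_paragraph_breakpoints=tf.constant([[0, 0, 1, 0]]),
    local_radius=2,
    relative_pos_max_distance=2,
    use_hard_g2l_mask=True)

# The returned structure carries all four masks and relative id tensors,
# e.g. side_inputs.l2l_att_mask and side_inputs.g2g_relative_att_ids.
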
def make_global_local_transformer_side_inputs_from_example_ids(
        long_example_ids: tf.Tensor,
        global_example_ids: tf.Tensor,
        sentence_ids: tf.Tensor,
        local_radius: int,
        relative_pos_max_distance: int,
        use_hard_g2l_mask: bool = False,
        use_hard_l2g_mask: bool = False,
        name: Optional[Text] = None) -> GlobalLocalTransformerSideInputs:
    """Makes side input tensors based on the given example and sentence ids.

  When packing examples (e.g. for pre-training), each example must have a
  unique id for `long_example_ids`/`global_example_ids`, and padding must
  also have a unique id distinct from all the example ids.

  When not packing examples, there will simply be two unique ids: one for
  example tokens, and another for padding.  Note that in this case, the classic
  BERT `input_mask` is a valid special case of `long_example_ids`.

  The other arguments have the same interpretation as in
  `make_global_local_transformer_side_inputs`.

  Args:
    long_example_ids: <int32>[batch_size, long_seq_len] Tensor of example ids of
      different packed examples.
    global_example_ids: <int32>[batch_size, global_seq_len] Tensor of example
      ids of different packed examples.
    sentence_ids: <int32>[batch_size, long_seq_len] Tensor of ids indicating
      which sentence each token belongs to. Here, "sentence" refers to a real
      natural language sentence, not a BERT "sentence" from the "next sentence
      prediction" task.
    local_radius: How many tokens to the left/right for input tokens to locally
      self-attend to. For example, a value of 1 would allow each token to only
      attend to 1 token to the left and 1 token to the right of it.
    relative_pos_max_distance: Maximum distance to use for relative position
      representations. All larger distances will be clipped to this value. Use 0
      to skip relative position representations entirely.
    use_hard_g2l_mask: If True, global tokens only attend to tokens of the
      corresponding sentences in the long input. If False, global tokens attend
      to all sentences within the corresponding global example.
    use_hard_l2g_mask: If True, long tokens only attend to their corresponding
      global tokens. If False, long tokens attend to all the global tokens
      within the corresponding global example.
    name: A name for the operation (optional).

  Returns:
    A `GlobalLocalTransformerSideInputs` with all relevant tensors set.
  """
    with tf.name_scope(name or 'make_global_local_transformer_side_inputs'):
        long_example_ids = tf.convert_to_tensor(long_example_ids)
        global_example_ids = tf.convert_to_tensor(global_example_ids)
        sentence_ids = tf.convert_to_tensor(sentence_ids)

        long_seq_len = tensor_utils.get_shape_list(long_example_ids)[1]
        global_seq_len = tensor_utils.get_shape_list(global_example_ids)[1]

        l2l_att_mask = feature_utils.make_local_segmented_att_mask(
            long_example_ids, local_radius)
        g2g_att_mask = feature_utils.make_segmented_att_mask(
            global_example_ids)

        l2g_att_mask = tf.cast(
            tf.equal(long_example_ids[:, :, tf.newaxis],
                     global_example_ids[:, tf.newaxis, :]), tf.int32)
        g2l_att_mask = tf.transpose(l2g_att_mask, perm=[0, 2, 1])

        if use_hard_g2l_mask:
            # Have each global token attend to just one sentence instead of having
            # it attend to all the sentences within a global example.
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            hard_g2l_att_mask = tf.cast(
                tf.equal(global_range[tf.newaxis, :, tf.newaxis],
                         sentence_ids[:, tf.newaxis, :]), tf.int32)
            g2l_att_mask *= hard_g2l_att_mask

        if use_hard_l2g_mask:
            # Have each long token attend to just the corresponding global token
            # instead of having it attend to all the global tokens within a
            # global example.
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            hard_l2g_att_mask = tf.cast(
                tf.equal(sentence_ids[:, :, tf.newaxis],
                         global_range[tf.newaxis, tf.newaxis, :]), tf.int32)
            l2g_att_mask *= hard_l2g_att_mask

        batch_size = tf.shape(long_example_ids)[0]

        l2l_relative_att_ids = None
        g2g_relative_att_ids = None
        l2g_relative_att_ids = None
        g2l_relative_att_ids = None

        if relative_pos_max_distance > 0:
            relative_pos_generator = feature_utils.RelativePositionGenerator(
                relative_pos_max_distance)
            l2l_relative_att_ids = relative_pos_generator.make_local_relative_att_ids(
                seq_len=long_seq_len,
                local_radius=local_radius,
                batch_size=batch_size)
            g2g_relative_att_ids = relative_pos_generator.make_relative_att_ids(
                seq_len=global_seq_len, batch_size=batch_size)
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            l2g_relative_att_ids = tf.cast(
                tf.equal(sentence_ids[:, :, tf.newaxis],
                         global_range[tf.newaxis, tf.newaxis, :]), tf.int32)
            g2l_relative_att_ids = tf.transpose(l2g_relative_att_ids,
                                                perm=[0, 2, 1])

            # For fused attention, l2l and l2g share the same relative vocabulary, as
            # do g2g and g2l, so we add an offset for l2g and g2l so their original
            # 0/1 ids don't collide with l2l and g2g relative position ids.
            l2g_relative_att_ids += relative_pos_generator.relative_vocab_size
            g2l_relative_att_ids += relative_pos_generator.relative_vocab_size

        return GlobalLocalTransformerSideInputs(
            l2l_att_mask=l2l_att_mask,
            g2g_att_mask=g2g_att_mask,
            l2g_att_mask=l2g_att_mask,
            g2l_att_mask=g2l_att_mask,
            l2l_relative_att_ids=l2l_relative_att_ids,
            g2g_relative_att_ids=g2g_relative_att_ids,
            l2g_relative_att_ids=l2g_relative_att_ids,
            g2l_relative_att_ids=g2l_relative_att_ids)
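
A minimal call sketch for the unpacked case, where (as the docstring notes) a
classic BERT-style `input_mask` is a valid `long_example_ids`; the values are
illustrative only, and the import of the function itself from its defining
module is assumed:

import tensorflow as tf

# One example (id 1) plus padding (id 0); sentence_ids index the matching
# global sentence tokens.
side_inputs = make_global_local_transformer_side_inputs_from_example_ids(
    long_example_ids=tf.constant([[1, 1, 1, 1, 1, 1, 0, 0]]),
    global_example_ids=tf.constant([[1, 1, 1, 0]]),
    sentence_ids=tf.constant([[0, 0, 0, 1, 1, 2, 0, 0]]),
    local_radius=2,
    relative_pos_max_distance=2)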