def test_split_into_blocks_3d(self):
    """Exercises `split_into_blocks` on a rank-3 tensor, one axis at a time."""
    # shape: [2, 4, 2]
    tensor = tf.constant([
        [[1, -1], [2, -2], [3, -3], [4, -4]],
        [[11, 21], [12, 22], [13, 23], [14, 24]],
    ])

    # Axis -2 has length 4, so `block_len=2` yields two full blocks per batch
    # entry and no padding is expected.
    expected_even_split = [
        [
            [[1, -1], [2, -2]],
            [[3, -3], [4, -4]],
        ],
        [
            [[11, 21], [12, 22]],
            [[13, 23], [14, 24]],
        ],
    ]
    actual_even_split = tensor_utils.split_into_blocks(
        tensor, block_len=2, axis=-2)
    self.assertAllEqual(expected_even_split, actual_even_split)

    # Same axis addressed as `axis=1`; `block_len=3` does not divide 4, so the
    # second block is expected to be filled up with all-zero rows.
    expected_padded_rows = [
        [
            [[1, -1], [2, -2], [3, -3]],
            [[4, -4], [0, 0], [0, 0]],
        ],
        [
            [[11, 21], [12, 22], [13, 23]],
            [[14, 24], [0, 0], [0, 0]],
        ],
    ]
    actual_padded_rows = tensor_utils.split_into_blocks(
        tensor, block_len=3, axis=1)
    self.assertAllEqual(expected_padded_rows, actual_padded_rows)

    # Innermost axis (length 2) with `block_len=3`: each row becomes a single
    # block carrying one trailing zero.
    expected_padded_inner = [
        [
            [[1, -1, 0]],
            [[2, -2, 0]],
            [[3, -3, 0]],
            [[4, -4, 0]],
        ],
        [
            [[11, 21, 0]],
            [[12, 22, 0]],
            [[13, 23, 0]],
            [[14, 24, 0]],
        ],
    ]
    actual_padded_inner = tensor_utils.split_into_blocks(
        tensor, block_len=3, axis=-1)
    self.assertAllEqual(expected_padded_inner, actual_padded_inner)
Esempio n. 2
0
  def test_split_into_blocks_1d(self):
    """Exercises `split_into_blocks` on a rank-1 tensor for many block sizes."""
    tensor = tf.range(6) + 1

    # (block_len, expected blocks) pairs that rely on the default padding;
    # the expected values show incomplete blocks filled with zeros.
    default_pad_cases = [
        (2, [[1, 2], [3, 4], [5, 6]]),
        (3, [[1, 2, 3], [4, 5, 6]]),
        (4, [[1, 2, 3, 4], [5, 6, 0, 0]]),
        (5, [[1, 2, 3, 4, 5], [6, 0, 0, 0, 0]]),
        (6, [[1, 2, 3, 4, 5, 6]]),
        (7, [[1, 2, 3, 4, 5, 6, 0]]),
    ]
    for block_len, expected_blocks in default_pad_cases:
      actual_blocks = tensor_utils.split_into_blocks(
          tensor, block_len=block_len, axis=0)
      self.assertAllEqual(expected_blocks, actual_blocks)

    # An explicit `pad_value` replaces the zero filler.
    custom_padded = tensor_utils.split_into_blocks(
        tensor, block_len=8, axis=0, pad_value=-1)
    self.assertAllEqual([[1, 2, 3, 4, 5, 6, -1, -1]], custom_padded)
Esempio n. 3
0
  def test_split_into_blocks_static_shape(self):
    """Checks the static shape of the result when the batch size is unknown."""
    # `placeholder_with_default` mimics the TF v1 case in which `batch_size`
    # has no statically known value.
    default_values = np.ones(shape=[2, 5], dtype=np.int32)
    unknown_batch_tensor = tf.compat.v1.placeholder_with_default(
        default_values, shape=[None, 5])

    blocked = tensor_utils.split_into_blocks(
        unknown_batch_tensor, block_len=3, axis=-1)

    # The leading dimension is expected to carry over unchanged, while the
    # length-5 axis is expected to become 2 blocks of length 3.
    leading_dim = unknown_batch_tensor.shape.as_list()[0]
    expected_static_shape = [leading_dim, 2, 3]
    self.assertAllEqual(expected_static_shape, blocked.shape.as_list())
def make_fixed_block_side_inputs(
    input_mask: tf.Tensor,
    num_tokens_per_block: int,
    local_radius: int,
    relative_pos_max_distance: int,
    use_hard_g2l_mask: bool = False,
    use_hard_l2g_mask: bool = False,
    global_token_id: int = 1,
    name: Optional[Text] = None
) -> Tuple[GlobalLocalTransformerSideInputs, tf.Tensor]:
    """Builds side inputs for a "fixed blocks" layout from an input mask.

    Only a BERT-style `input_mask` feature is needed: the long input is cut
    into consecutive blocks of `num_tokens_per_block` tokens and one global
    token is generated per block, so no global features have to be supplied.
    The "fixed blocks" experiments for NQ and OpenKP produce these inputs
    during example generation instead; this function is included to
    illustrate how the same side inputs can be derived.

    Args:
      input_mask: <int32>[batch_size, long_seq_len] Tensor holding 1 for actual
        tokens and 0 for padding, in the same format as original BERT.
        `long_seq_len` must be statically known.
      num_tokens_per_block: Positive integer number of long tokens assigned to
        each global token.  The original BERT pre-training data (also used for
        ETC pre-training) implied a value of about 27, but values like 16 or 32
        would also be reasonable.
      local_radius: Number of tokens to the left/right that each input token
        may locally self-attend to.  For example, a value of 1 lets each token
        attend only to 1 token on its left and 1 token on its right.
      relative_pos_max_distance: Maximum distance used for relative position
        representations; larger distances are clipped to this value.  Use 0 to
        skip relative position representations entirely.
      use_hard_g2l_mask: If True, each global token attends only to the long
        tokens of its own block.  If False (the default setup), global tokens
        attend to all non-padding long tokens.
      use_hard_l2g_mask: If True, each long token attends only to the global
        token of its own block.  If False (the default setup), long tokens
        attend to all non-padding global tokens.
      global_token_id: Integer id assigned to global tokens.  The default `1`
        is the value used during ETC pre-training.
      name: A name for the operation (optional).

    Returns:
      A tuple `(side_inputs, global_token_ids)`:
        side_inputs: A `GlobalLocalTransformerSideInputs` object containing all
          side input tensors.
        global_token_ids: <int32>[batch_size, global_seq_len] Tensor of global
          token ids suitable to pass into `EtcModel`.  Every global token uses
          the same `global_token_id`, except for padding tokens.

    Raises:
      ValueError: If `num_tokens_per_block` is not positive, or if
        `long_seq_len` is not statically known.
    """
    if num_tokens_per_block <= 0:
        raise ValueError('`num_tokens_per_block` must be positive.')

    scope_name = name or 'make_fixed_block_side_inputs'
    with tf.name_scope(scope_name):
        input_mask = tf.convert_to_tensor(input_mask)

        mask_dims = tensor_utils.get_shape_list(input_mask)
        batch_size = mask_dims[0]
        long_seq_len = input_mask.shape.as_list()[1]
        if long_seq_len is None:
            raise ValueError('`long_seq_len` must be statically known.')

        # Ceiling division: a trailing partial block still gets a global token.
        global_seq_len = -(-long_seq_len // num_tokens_per_block)

        # [batch_size, global_seq_len, num_tokens_per_block]
        mask_blocks = tensor_utils.split_into_blocks(
            input_mask, block_len=num_tokens_per_block, axis=-1)
        assert mask_blocks.shape.as_list()[1] == global_seq_len

        # [batch_size, global_seq_len]
        # Largest mask value within each block, capped at 1.
        block_max = tf.reduce_max(mask_blocks, axis=-1)
        global_input_mask = tf.minimum(block_max, 1)

        # [long_seq_len]
        # Block index of every long token, cut back to the real length since
        # the last block may be only partially filled.
        block_indices = tf.range(global_seq_len, dtype=tf.int32)
        block_index_per_token = tf.repeat(block_indices, num_tokens_per_block)
        block_index_per_token = block_index_per_token[:long_seq_len]

        # [batch_size, long_seq_len]
        sentence_ids = tf.broadcast_to(block_index_per_token,
                                       [batch_size, long_seq_len])

        side_inputs = make_global_local_transformer_side_inputs_from_example_ids(
            long_example_ids=input_mask,
            global_example_ids=global_input_mask,
            sentence_ids=sentence_ids,
            local_radius=local_radius,
            relative_pos_max_distance=relative_pos_max_distance,
            use_hard_g2l_mask=use_hard_g2l_mask,
            use_hard_l2g_mask=use_hard_l2g_mask)

        # Global positions whose mask is 0 get id 0.
        return side_inputs, global_token_id * global_input_mask
Esempio n. 5
0
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
    """Makes local attention mask preventing attention across different segments.

    Restricts local self-attention to attend within segments, such that tokens
    can only attend to local tokens from the same segment id. The tokens in a
    segment do not need to be contiguous, but attention is still constrained by
    `local_radius`. The output can be used as `l2l_att_mask` in
    `layers.GlobalLocalTransformerLayers` for example.

    Args:
      segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
        which must be non-negative.
      local_radius: The local radius as expected by
        `layers.GlobalLocalTransformerLayers`. Must be positive.
      name: A name for the operation (optional).

    Returns:
      <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.

    Raises:
      ValueError: If `local_radius` is not positive, or if `segment_ids` is not
        a 2-D tensor.
    """
    # Reject an invalid radius up front: it is used below as the block length
    # and as a tiling multiple, where a non-positive value has no meaning.
    if local_radius <= 0:
        raise ValueError('`local_radius` must be positive.')

    with tf.name_scope(name or 'make_local_segmented_att_mask'):
        segment_ids = tf.convert_to_tensor(segment_ids)

        if segment_ids.shape.rank != 2:
            raise ValueError('`segment_ids` must be a 2-D tensor.')

        batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

        # Add 1 so that segment id `0` doesn't coincide with `0` padding values
        # introduced later by `tensor_utils.concat_3_blocks()` for example.
        segment_ids += 1

        # [batch_size, num_blocks, local_radius]
        blocked_segment_ids = tensor_utils.split_into_blocks(
            segment_ids, block_len=local_radius, axis=1)

        # [batch_size, num_blocks, 3*local_radius]
        concat_blocked_segment_ids = tensor_utils.concat_3_blocks(
            blocked_segment_ids)

        # [batch_size, num_blocks, local_radius, 3*local_radius]
        # One copy of the concatenated ids per token position in the block.
        tiled_segment_ids = tf.tile(
            concat_blocked_segment_ids[:, :, tf.newaxis, :],
            [1, 1, local_radius, 1])

        # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
        blocked_unskewed_segment_ids = tensor_utils.unskew_elements_right(
            tiled_segment_ids, axis=-1)

        # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
        flat_unskewed_segment_ids = tensor_utils.flatten_dims(
            blocked_unskewed_segment_ids, first_dim=1, last_dim=2)

        # [batch_size, seq_len, 2*local_radius + 1]
        # Keep only the first `seq_len` positions, dropping any positions that
        # blocking added beyond the original sequence length.
        unskewed_segment_ids = tf.slice(flat_unskewed_segment_ids,
                                        begin=[0, 0, 0],
                                        size=[-1, seq_len, -1])

        # [batch_size, seq_len, 1]
        # Index `local_radius` is the middle of the 2*local_radius + 1 window.
        center_token_segment_id = unskewed_segment_ids[:, :, local_radius:(
            local_radius + 1)]

        # [batch_size, seq_len, 2*local_radius + 1]
        # 1 where a window entry has the same segment id as the center token.
        result = tf.cast(
            tf.equal(unskewed_segment_ids, center_token_segment_id), tf.int32)

        # Use `reshape` to set the static shape when known.
        return tf.reshape(result, [batch_size, seq_len, 2 * local_radius + 1])