  def test_split_into_blocks_3d(self):
    # shape: [2, 4, 2]
    tensor = tf.constant([
        [[1, -1], [2, -2], [3, -3], [4, -4]],
        [[11, 21], [12, 22], [13, 23], [14, 24]],
    ])

    self.assertAllEqual(
        [
            [
                [[1, -1], [2, -2]],
                [[3, -3], [4, -4]],
            ],
            [
                [[11, 21], [12, 22]],
                [[13, 23], [14, 24]],
            ],
        ],
        tensor_utils.split_into_blocks(tensor, block_len=2, axis=-2))

    self.assertAllEqual(
        [
            [
                [[1, -1], [2, -2], [3, -3]],
                [[4, -4], [0, 0], [0, 0]],
            ],
            [
                [[11, 21], [12, 22], [13, 23]],
                [[14, 24], [0, 0], [0, 0]],
            ],
        ],
        tensor_utils.split_into_blocks(tensor, block_len=3, axis=1))

    self.assertAllEqual(
        [
            [
                [[1, -1, 0]],
                [[2, -2, 0]],
                [[3, -3, 0]],
                [[4, -4, 0]],
            ],
            [
                [[11, 21, 0]],
                [[12, 22, 0]],
                [[13, 23, 0]],
                [[14, 24, 0]],
            ],
        ],
        tensor_utils.split_into_blocks(tensor, block_len=3, axis=-1))
  def test_split_into_blocks_1d(self):
    tensor = tf.range(6) + 1

    self.assertAllEqual([[1, 2], [3, 4], [5, 6]],
                        tensor_utils.split_into_blocks(
                            tensor, block_len=2, axis=0))
    self.assertAllEqual([[1, 2, 3], [4, 5, 6]],
                        tensor_utils.split_into_blocks(
                            tensor, block_len=3, axis=0))
    self.assertAllEqual([[1, 2, 3, 4], [5, 6, 0, 0]],
                        tensor_utils.split_into_blocks(
                            tensor, block_len=4, axis=0))
    self.assertAllEqual([[1, 2, 3, 4, 5], [6, 0, 0, 0, 0]],
                        tensor_utils.split_into_blocks(
                            tensor, block_len=5, axis=0))
    self.assertAllEqual([[1, 2, 3, 4, 5, 6]],
                        tensor_utils.split_into_blocks(
                            tensor, block_len=6, axis=0))
    self.assertAllEqual([[1, 2, 3, 4, 5, 6, 0]],
                        tensor_utils.split_into_blocks(
                            tensor, block_len=7, axis=0))
    self.assertAllEqual([[1, 2, 3, 4, 5, 6, -1, -1]],
                        tensor_utils.split_into_blocks(
                            tensor, block_len=8, axis=0, pad_value=-1))
  def test_split_into_blocks_static_shape(self):
    # We use `placeholder_with_default` to simulate the TF v1 situation where
    # a static `batch_size` is unknown.
    tensor = tf.compat.v1.placeholder_with_default(
        np.ones(shape=[2, 5], dtype=np.int32), shape=[None, 5])

    result = tensor_utils.split_into_blocks(tensor, block_len=3, axis=-1)

    static_batch_size = tensor.shape.as_list()[0]
    self.assertAllEqual([static_batch_size, 2, 3], result.shape.as_list())
def make_fixed_block_side_inputs(
    input_mask: tf.Tensor,
    num_tokens_per_block: int,
    local_radius: int,
    relative_pos_max_distance: int,
    use_hard_g2l_mask: bool = False,
    use_hard_l2g_mask: bool = False,
    global_token_id: int = 1,
    name: Optional[Text] = None
) -> Tuple[GlobalLocalTransformerSideInputs, tf.Tensor]:
  """Utility for creating side inputs in a "fixed blocks" pattern.

  The "fixed blocks" experiments for NQ and OpenKP are implemented via example
  generation rather than using this function, but we include this function to
  illustrate how side inputs can be generated given just a BERT-style
  `input_mask` feature. The corresponding global tokens are generated as part
  of this function too, so no global features are required as input.

  Args:
    input_mask: <int32>[batch_size, long_seq_len] Tensor of 1 and 0 values,
      with 1 for actual tokens and 0 for padding. This is the same format as
      original BERT. `long_seq_len` must be statically known.
    num_tokens_per_block: Positive integer number of long tokens to assign to
      each global token. For pre-training on the original BERT data (which was
      also used for ETC pre-training), the dataset implied a value of about 27,
      but values like 16 or 32 would also be reasonable.
    local_radius: How many tokens to the left/right for input tokens to locally
      self-attend to. For example, a value of 1 would allow each token to only
      attend to 1 token to the left and 1 token to the right of it.
    relative_pos_max_distance: Maximum distance to use for relative position
      representations. All larger distances will be clipped to this value. Use
      0 to skip relative position representations entirely.
    use_hard_g2l_mask: If True, global tokens only attend to tokens of their
      corresponding block in the long input. If False, global tokens attend to
      all non-padding long tokens. False is the default setup.
    use_hard_l2g_mask: If True, long tokens only attend to the global token
      corresponding to their block. If False, long tokens attend to all the
      non-padding global tokens. False is the default setup.
    global_token_id: Integer id to use for global tokens. The default is `1`,
      which was the value used during ETC pre-training.
    name: A name for the operation (optional).

  Returns:
    A tuple with the following 2 elements:
      side_inputs: A `GlobalLocalTransformerSideInputs` object containing all
        side input tensors.
      global_token_ids: <int32>[batch_size, global_seq_len] Tensor of global
        token ids suitable to pass into `EtcModel`. All global tokens will use
        the same `global_token_id`, except for padding tokens.
""" if num_tokens_per_block <= 0: raise ValueError('`num_tokens_per_block` must be positive.') with tf.name_scope(name or 'make_fixed_block_side_inputs'): input_mask = tf.convert_to_tensor(input_mask) batch_size = tensor_utils.get_shape_list(input_mask)[0] long_seq_len = input_mask.shape.as_list()[1] if long_seq_len is None: raise ValueError('`long_seq_len` must be statically known.') global_seq_len = (long_seq_len + num_tokens_per_block - 1) // num_tokens_per_block # [batch_size, global_seq_len, num_tokens_per_block] blocked_input_mask = tensor_utils.split_into_blocks( input_mask, block_len=num_tokens_per_block, axis=-1) assert blocked_input_mask.shape.as_list()[1] == global_seq_len # [batch_size, global_seq_len] global_input_mask = tf.minimum( tf.reduce_max(blocked_input_mask, axis=-1), 1) # [long_seq_len] sentence_ids = tf.repeat(tf.range(global_seq_len, dtype=tf.int32), num_tokens_per_block)[:long_seq_len] # [batch_size, long_seq_len] sentence_ids = tf.broadcast_to(sentence_ids, [batch_size, long_seq_len]) side_inputs = make_global_local_transformer_side_inputs_from_example_ids( long_example_ids=input_mask, global_example_ids=global_input_mask, sentence_ids=sentence_ids, local_radius=local_radius, relative_pos_max_distance=relative_pos_max_distance, use_hard_g2l_mask=use_hard_g2l_mask, use_hard_l2g_mask=use_hard_l2g_mask) global_token_ids = global_token_id * global_input_mask return side_inputs, global_token_ids
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
  """Makes local attention mask preventing attention across different segments.

  Restricts local self-attention to attend within segments, such that tokens
  can only attend to local tokens from the same segment id. The tokens in a
  segment do not need to be contiguous, but attention is still constrained by
  `local_radius`. The output can be used as `l2l_att_mask` in
  `layers.GlobalLocalTransformerLayers` for example.

  Args:
    segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
      which must be non-negative.
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.
  """
  with tf.name_scope(name or 'make_local_segmented_att_mask'):
    segment_ids = tf.convert_to_tensor(segment_ids)

    if segment_ids.shape.rank != 2:
      raise ValueError('`segment_ids` must be a 2-D tensor.')

    batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

    # Add 1 so that segment id `0` doesn't coincide with `0` padding values
    # introduced later by `tensor_utils.concat_3_blocks()` for example.
    segment_ids += 1

    # [batch_size, num_blocks, local_radius]
    blocked_segment_ids = tensor_utils.split_into_blocks(
        segment_ids, block_len=local_radius, axis=1)

    # [batch_size, num_blocks, 3*local_radius]
    concat_blocked_segment_ids = tensor_utils.concat_3_blocks(
        blocked_segment_ids)

    # [batch_size, num_blocks, local_radius, 3*local_radius]
    tiled_segment_ids = tf.tile(
        concat_blocked_segment_ids[:, :, tf.newaxis, :],
        [1, 1, local_radius, 1])

    # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
    blocked_unskewed_segment_ids = tensor_utils.unskew_elements_right(
        tiled_segment_ids, axis=-1)

    # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
    flat_unskewed_segment_ids = tensor_utils.flatten_dims(
        blocked_unskewed_segment_ids, first_dim=1, last_dim=2)

    # [batch_size, seq_len, 2*local_radius + 1]
    unskewed_segment_ids = tf.slice(
        flat_unskewed_segment_ids, begin=[0, 0, 0], size=[-1, seq_len, -1])

    # [batch_size, seq_len, 1]
    center_token_segment_id = unskewed_segment_ids[:, :, local_radius:(
        local_radius + 1)]

    # [batch_size, seq_len, 2*local_radius + 1]
    result = tf.cast(
        tf.equal(unskewed_segment_ids, center_token_segment_id), tf.int32)

    # Use `reshape` to set the static shape when known.
    return tf.reshape(result, [batch_size, seq_len, 2 * local_radius + 1])
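
# The helper below is not part of the original ETC code; it is an illustrative
# usage sketch of `make_local_segmented_att_mask`. The segment ids, the radius,
# and the helper name are assumptions for demonstration only.
def _example_make_local_segmented_att_mask():
  """Sketch: local attention mask for two segments with assumed values."""
  # One example with 5 tokens: tokens 0-2 form segment 0, tokens 3-4 segment 1.
  segment_ids = tf.constant([[0, 0, 0, 1, 1]])
  mask = make_local_segmented_att_mask(segment_ids, local_radius=2)
  # `mask` has shape [1, 5, 2*2 + 1]. For token 2, the window covers tokens
  # 0 through 4, but only tokens 0-2 share its segment, so its mask row is
  # [1, 1, 1, 0, 0].
  return mask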