def test_skew_elements_right_unskew_elements_right_round_trip(self):
  """Verifies `unskew_elements_right` inverts `skew_elements_right`."""
  np.random.seed(1234)
  tensor = tf.constant(np.random.normal(size=[3, 5, 7, 11]))
  # Round-trip along several axes, including a negative axis index.
  for axis in (1, 2, -1):
    skewed = tensor_utils.skew_elements_right(tensor, axis)
    self.assertAllEqual(
        tensor, tensor_utils.unskew_elements_right(skewed, axis))
def test_unskew_elements_right_2d(self):
  """Tests `unskew_elements_right` on 2-D inputs, including error cases."""
  tensor = tf.constant([
      [1, 2, 3, 0, 0],  #
      [0, 4, 5, 6, 0],  #
      [0, 0, 7, 8, 9],  #
  ])
  # Unskewing along the last axis, specified as a negative index.
  self.assertAllEqual(
      [
          [1, 2, 3],  #
          [4, 5, 6],  #
          [7, 8, 9],  #
      ],
      tensor_utils.unskew_elements_right(tensor, -1))

  # Axis 0 is not a valid axis to unskew along.
  with self.assertRaises(ValueError):
    tensor_utils.unskew_elements_right(tensor, 0)

  # `(expected, skewed_input)` cases for axis 1, covering tall, wide,
  # single-row, and single-column shapes.
  cases = [
      (
          [
              [1, 2],  #
              [3, 4],  #
              [5, 6],  #
              [7, 8],  #
              [9, 10],  #
          ],
          [
              [1, 2, 0, 0, 0, 0],  #
              [0, 3, 4, 0, 0, 0],  #
              [0, 0, 5, 6, 0, 0],  #
              [0, 0, 0, 7, 8, 0],  #
              [0, 0, 0, 0, 9, 10],  #
          ],
      ),
      (
          [
              [1, 2, 3, 4],  #
              [5, 6, 7, 8],  #
          ],
          [
              [1, 2, 3, 4, 0],  #
              [0, 5, 6, 7, 8],  #
          ],
      ),
      (
          [
              [1, 2, 3],  #
          ],
          [
              [1, 2, 3],  #
          ],
      ),
      (
          [
              [1],  #
              [2],  #
              [3],  #
          ],
          [
              [1, 0, 0],  #
              [0, 2, 0],  #
              [0, 0, 3],  #
          ],
      ),
  ]
  for expected, skewed in cases:
    self.assertAllEqual(expected,
                        tensor_utils.unskew_elements_right(skewed, 1))

  # An input too narrow to have been produced by skewing is rejected.
  with self.assertRaises(ValueError):
    tensor_utils.unskew_elements_right(
        [
            [1, 2],  #
            [3, 4],  #
            [5, 6],  #
        ], 1)
def test_unskew_elements_right_4d(self):
  """Tests `unskew_elements_right` on a rank-4 tensor along several axes."""
  # shape: [2, 3, 2, 2]
  expected_tensor = tf.constant([
      [
          [[1, -1], [2, -2]],  #
          [[3, -3], [4, -4]],  #
          [[5, -5], [6, -6]],  #
      ],  #
      [
          [[.1, -.1], [.2, -.2]],  #
          [[.3, -.3], [.4, -.4]],  #
          [[.5, -.5], [.6, -.6]],  #
      ],  #
  ])

  # Each case pairs a skewed input with the axis it was skewed along; all
  # of them should unskew back to `expected_tensor`.
  skewed_inputs_and_axes = [
      (
          [
              [
                  [[1, -1, 0], [0, 2, -2]],  #
                  [[3, -3, 0], [0, 4, -4]],  #
                  [[5, -5, 0], [0, 6, -6]],  #
              ],  #
              [
                  [[.1, -.1, 0], [0, .2, -.2]],  #
                  [[.3, -.3, 0], [0, .4, -.4]],  #
                  [[.5, -.5, 0], [0, .6, -.6]],  #
              ],  #
          ],
          -1,
      ),
      (
          [
              [
                  [[1, -1], [2, -2], [0, 0], [0, 0]],  #
                  [[0, 0], [3, -3], [4, -4], [0, 0]],  #
                  [[0, 0], [0, 0], [5, -5], [6, -6]],  #
              ],  #
              [
                  [[.1, -.1], [.2, -.2], [0, 0], [0, 0]],  #
                  [[0, 0], [.3, -.3], [.4, -.4], [0, 0]],  #
                  [[0, 0], [0, 0], [.5, -.5], [.6, -.6]],  #
              ],  #
          ],
          -2,
      ),
      (
          [
              [
                  [[1, -1], [2, -2]],  #
                  [[3, -3], [4, -4]],  #
                  [[5, -5], [6, -6]],  #
                  [[0, 0], [0, 0]],  #
              ],  #
              [
                  [[0, 0], [0, 0]],  #
                  [[.1, -.1], [.2, -.2]],  #
                  [[.3, -.3], [.4, -.4]],  #
                  [[.5, -.5], [.6, -.6]],  #
              ],  #
          ],
          1,
      ),
  ]
  for skewed, axis in skewed_inputs_and_axes:
    self.assertAllClose(
        expected_tensor, tensor_utils.unskew_elements_right(skewed, axis))
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
  """Builds a local attention mask that blocks cross-segment attention.

  Restricts local self-attention so that each token may only attend to
  nearby tokens carrying the same segment id. The tokens of a segment need
  not be contiguous, but attention is still bounded by `local_radius`.
  The output can be used as `l2l_att_mask` in
  `layers.GlobalLocalTransformerLayers` for example.

  Args:
    segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
      which must be non-negative.
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.
  """
  with tf.name_scope(name or 'make_local_segmented_att_mask'):
    segment_ids = tf.convert_to_tensor(segment_ids)

    if segment_ids.shape.rank != 2:
      raise ValueError('`segment_ids` must be a 2-D tensor.')

    batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

    # Shift all ids up by 1 so that real ids can never collide with the `0`
    # padding values introduced later (e.g. by
    # `tensor_utils.concat_3_blocks()`).
    segment_ids += 1

    # [batch_size, num_blocks, local_radius]
    blocked_ids = tensor_utils.split_into_blocks(
        segment_ids, block_len=local_radius, axis=1)

    # [batch_size, num_blocks, 3*local_radius]
    context_ids = tensor_utils.concat_3_blocks(blocked_ids)

    # [batch_size, num_blocks, local_radius, 3*local_radius]
    tiled_ids = tf.tile(context_ids[:, :, tf.newaxis, :],
                        [1, 1, local_radius, 1])

    # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
    unskewed_blocked_ids = tensor_utils.unskew_elements_right(
        tiled_ids, axis=-1)

    # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
    flat_ids = tensor_utils.flatten_dims(
        unskewed_blocked_ids, first_dim=1, last_dim=2)

    # Trim any block padding beyond the true sequence length.
    # [batch_size, seq_len, 2*local_radius + 1]
    window_ids = tf.slice(
        flat_ids, begin=[0, 0, 0], size=[-1, seq_len, -1])

    # The id of each attending token sits at the center of its window.
    # [batch_size, seq_len, 1]
    center_ids = window_ids[:, :, local_radius:(local_radius + 1)]

    # Attention is allowed exactly where the window id matches the center id.
    # [batch_size, seq_len, 2*local_radius + 1]
    mask = tf.cast(tf.equal(window_ids, center_ids), tf.int32)

    # `reshape` pins down the static shape when it is known.
    return tf.reshape(mask, [batch_size, seq_len, 2 * local_radius + 1])