def make_ngram_labels(label_start_idx: tf.Tensor,
                      label_phrase_len: tf.Tensor,
                      long_max_length: int,
                      kp_max_length: int,
                      additive_smoothing_mass: float = 1e-6) -> tf.Tensor:
  """Makes ngram labels for `tf.nn.softmax_cross_entropy_with_logits`.

  Args:
    label_start_idx: <int32>[batch_size, num_labels] Tensor of the index of
      the first word in each correct key phrase. There must be at least 1
      correct key phrase, and if there are less than `num_labels` then `-1`
      is used as right padding. All values must be less than
      `long_max_length`. `num_labels` is the maximum number of key phrase
      labels, which is 3 for OpenKP.
    label_phrase_len: <int32>[batch_size, num_labels] Tensor of the
      corresponding length of the key phrase, again using `-1` to pad.
      Non-padding values must be in the inclusive range [1, kp_max_length].
    long_max_length: Integer maximum number of words in the document.
    kp_max_length: Integer maximum number of words in a key phrase.
    additive_smoothing_mass: Total probability mass (on top of `1.0` mass for
      the actual label) to add for additive smoothing. We use a minimum of
      1e-6 to avoid any potential division by 0.

  Returns:
    <float32>[batch_size, kp_max_length * long_max_length] Tensor of label
    probabilities based on the inputs. Each row sums to 1.0, and the order of
    entries is compatible with reshaping `ngram_logits` from shape
    [batch_size, kp_max_length, long_max_length] to match these labels.
  """
  # [batch_size, num_labels, kp_max_length]
  phrase_len_one_hot = tf.one_hot(
      label_phrase_len - 1, depth=kp_max_length, dtype=tf.float32)

  # [batch_size, num_labels, long_max_length]
  start_idx_one_hot = tf.one_hot(
      label_start_idx, depth=long_max_length, dtype=tf.float32)

  # [batch_size, kp_max_length, long_max_length]
  combined_one_hot = tf.einsum('bnk,bnl->bkl', phrase_len_one_hot,
                               start_idx_one_hot)

  # [batch_size, kp_max_length * long_max_length]
  unnormalized_labels = tensor_utils.flatten_dims(combined_one_hot,
                                                  first_dim=1)

  unsmoothed_labels = (
      unnormalized_labels /
      (tf.reduce_sum(unnormalized_labels, axis=-1, keepdims=True) + 1e-6))

  # Use at least 1e-6 smoothing mass to avoid divide by 0.
  additive_smoothing_mass = max(additive_smoothing_mass, 1e-6)

  num_classes = kp_max_length * long_max_length
  smoothed_labels = unsmoothed_labels + additive_smoothing_mass / num_classes
  return (smoothed_labels /
          tf.reduce_sum(smoothed_labels, axis=-1, keepdims=True))
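

# A minimal usage sketch, not part of the original module; the helper name and
# tensor sizes below are illustrative only. With `kp_max_length=2` and
# `long_max_length=4`, a single correct key phrase starting at word 1 with
# length 2 places (nearly) all probability mass on flattened class index
# 1 * long_max_length + 1 = 5; the `-1` padding entries contribute nothing
# because `tf.one_hot` maps out-of-range indices to all-zero rows.
def _example_make_ngram_labels():
  label_start_idx = tf.constant([[1, -1, -1]])  # [batch_size=1, num_labels=3]
  label_phrase_len = tf.constant([[2, -1, -1]])  # [batch_size=1, num_labels=3]
  labels = make_ngram_labels(
      label_start_idx=label_start_idx,
      label_phrase_len=label_phrase_len,
      long_max_length=4,
      kp_max_length=2)
  # `labels` has shape [1, 8] and each row sums to 1.0, matching logits
  # reshaped from [batch_size, kp_max_length, long_max_length] to
  # [batch_size, kp_max_length * long_max_length].
  return labels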
def test_flatten_dims(self):
  tensor = tf.reshape(tf.range(2 * 3 * 4 * 5), [2, 3, 4, 5])

  result1 = tensor_utils.flatten_dims(tensor, last_dim=1)
  self.assertAllEqual([6, 4, 5], result1.shape)
  self.assertAllEqual(tf.range(2 * 3 * 4 * 5), tf.reshape(result1, [-1]))

  self.assertAllEqual([2, 3, 20],
                      tensor_utils.flatten_dims(tensor, first_dim=-2).shape)
  self.assertAllEqual(
      [2, 12, 5],
      tensor_utils.flatten_dims(tensor, first_dim=1, last_dim=-2).shape)
  self.assertAllEqual([24, 5],
                      tensor_utils.flatten_dims(tensor, last_dim=-2).shape)
  self.assertAllEqual([2 * 3 * 4 * 5],
                      tensor_utils.flatten_dims(tensor).shape)
  self.assertAllEqual(
      [2, 3, 4, 5],
      tensor_utils.flatten_dims(tensor, first_dim=1, last_dim=-3).shape)
  self.assertAllEqual([7], tensor_utils.flatten_dims(tf.ones([7])).shape)
  self.assertAllEqual([12], tensor_utils.flatten_dims(tf.ones([4, 3])).shape)

  with self.assertRaises(ValueError):
    tensor_utils.flatten_dims(tensor, first_dim=4)

  with self.assertRaises(ValueError):
    tensor_utils.flatten_dims(tensor, last_dim=-5)

  with self.assertRaises(ValueError):
    tensor_utils.flatten_dims(tensor, first_dim=2, last_dim=1)
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
  """Makes local attention mask preventing attention across different segments.

  Restricts local self-attention to attend within segments, such that tokens
  can only attend to local tokens from the same segment id. The tokens in a
  segment do not need to be contiguous, but attention is still constrained by
  `local_radius`. The output can be used as `l2l_att_mask` in
  `layers.GlobalLocalTransformerLayers` for example.

  Args:
    segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
      which must be non-negative.
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.
  """
  with tf.name_scope(name or 'make_local_segmented_att_mask'):
    segment_ids = tf.convert_to_tensor(segment_ids)

    if segment_ids.shape.rank != 2:
      raise ValueError('`segment_ids` must be a 2-D tensor.')

    batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

    # Add 1 so that segment id `0` doesn't coincide with `0` padding values
    # introduced later by `tensor_utils.concat_3_blocks()` for example.
    segment_ids += 1

    # [batch_size, num_blocks, local_radius]
    blocked_segment_ids = tensor_utils.split_into_blocks(
        segment_ids, block_len=local_radius, axis=1)

    # [batch_size, num_blocks, 3*local_radius]
    concat_blocked_segment_ids = tensor_utils.concat_3_blocks(
        blocked_segment_ids)

    # [batch_size, num_blocks, local_radius, 3*local_radius]
    tiled_segment_ids = tf.tile(
        concat_blocked_segment_ids[:, :, tf.newaxis, :],
        [1, 1, local_radius, 1])

    # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
    blocked_unskewed_segment_ids = tensor_utils.unskew_elements_right(
        tiled_segment_ids, axis=-1)

    # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
    flat_unskewed_segment_ids = tensor_utils.flatten_dims(
        blocked_unskewed_segment_ids, first_dim=1, last_dim=2)

    # [batch_size, seq_len, 2*local_radius + 1]
    unskewed_segment_ids = tf.slice(
        flat_unskewed_segment_ids, begin=[0, 0, 0], size=[-1, seq_len, -1])

    # [batch_size, seq_len, 1]
    center_token_segment_id = unskewed_segment_ids[:, :, local_radius:(
        local_radius + 1)]

    # [batch_size, seq_len, 2*local_radius + 1]
    result = tf.cast(
        tf.equal(unskewed_segment_ids, center_token_segment_id), tf.int32)

    # Use `reshape` to set the static shape when known.
    return tf.reshape(result, [batch_size, seq_len, 2 * local_radius + 1])
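

# An illustrative usage sketch, not part of the original module; the helper
# name and tensor values are assumptions for demonstration. With
# `local_radius=1` the mask has shape [batch_size, seq_len, 3], where the last
# axis covers relative positions [-1, 0, +1]; entries are 0 wherever a token's
# window crosses a segment boundary or runs past the sequence.
def _example_make_local_segmented_att_mask():
  segment_ids = tf.constant([[0, 0, 1, 1]])  # [batch_size=1, seq_len=4]
  mask = make_local_segmented_att_mask(segment_ids, local_radius=1)
  # `mask[0, 1]` is [1, 1, 0]: token 1 may attend to token 0 and to itself
  # (same segment) but not to token 2, which has a different segment id.
  return mask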