from typing import Optional, Text

import tensorflow as tf

import tensor_utils  # Helper module used below (flatten_dims, split_into_blocks, etc.); exact import path assumed.


def make_ngram_labels(label_start_idx: tf.Tensor,
                      label_phrase_len: tf.Tensor,
                      long_max_length: int,
                      kp_max_length: int,
                      additive_smoothing_mass: float = 1e-6) -> tf.Tensor:
    """Makes ngram labels for `tf.nn.softmax_cross_entropy_with_logits`.

  Args:
    label_start_idx: <int32>[batch_size, num_labels] Tensor of the index of the
      first word in each correct key phrase. There must be at least 1 correct
      key phrase, and if there are less than `num_labels` then `-1` is used as
      right padding. All values must be less than `long_max_length`.
      `num_labels` is the maximum number of key phrase labels, which is 3 for
      OpenKP.
    label_phrase_len: <int32>[batch_size, num_labels] Tensor of the
      corresponding length of the key phrase, again using `-1` to pad.
      Non-padding values must be in the inclusive range [1, kp_max_length].
    long_max_length: Integer maximum number of words in the document.
    kp_max_length: Integer maximum number of words in a key phrase.
    additive_smoothing_mass: Total probability mass (on top of `1.0` mass for
      the actual label) to add for additive smoothing. We use a minimum of 1e-6
      to avoid any potential division by 0.

  Returns:
    <float32>[batch_size, kp_max_length * long_max_length] Tensor of label
    probabilities based on the inputs. Each row sums to 1.0, and the order
    of entries is compatible with reshaping `ngram_logits` from shape
    [batch_size, kp_max_length, long_max_length] to match these labels.
  """
    # [batch_size, num_labels, kp_max_length]
    phrase_len_one_hot = tf.one_hot(label_phrase_len - 1,
                                    depth=kp_max_length,
                                    dtype=tf.float32)

    # [batch_size, num_labels, long_max_length]
    start_idx_one_hot = tf.one_hot(label_start_idx,
                                   depth=long_max_length,
                                   dtype=tf.float32)

    # [batch_size, kp_max_length, long_max_length]
    combined_one_hot = tf.einsum('bnk,bnl->bkl', phrase_len_one_hot,
                                 start_idx_one_hot)

    # [batch_size, kp_max_length * long_max_length]
    unnormalized_labels = tensor_utils.flatten_dims(combined_one_hot,
                                                    first_dim=1)

    unsmoothed_labels = (
        unnormalized_labels /
        (tf.reduce_sum(unnormalized_labels, axis=-1, keepdims=True) + 1e-6))

    # Use at least 1e-6 smoothing mass to avoid divide by 0.
    additive_smoothing_mass = max(additive_smoothing_mass, 1e-6)

    num_classes = kp_max_length * long_max_length
    smoothed_labels = unsmoothed_labels + additive_smoothing_mass / num_classes

    return (smoothed_labels /
            tf.reduce_sum(smoothed_labels, axis=-1, keepdims=True))
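

# Illustrative pairing of `make_ngram_labels` with the softmax cross entropy it
# is documented for. `_example_ngram_loss` is a hypothetical helper (not from
# the original source); it assumes `ngram_logits` has static shape
# [batch_size, kp_max_length, long_max_length].
def _example_ngram_loss(ngram_logits: tf.Tensor,
                        label_start_idx: tf.Tensor,
                        label_phrase_len: tf.Tensor) -> tf.Tensor:
    """Example loss over ngram logits flattened to match the label layout."""
    kp_max_length = ngram_logits.shape[1]
    long_max_length = ngram_logits.shape[2]
    labels = make_ngram_labels(
        label_start_idx=label_start_idx,
        label_phrase_len=label_phrase_len,
        long_max_length=long_max_length,
        kp_max_length=kp_max_length)
    # Flatten [batch_size, kp_max_length, long_max_length] the same way the
    # labels are flattened so the class ordering matches.
    flat_logits = tensor_utils.flatten_dims(ngram_logits, first_dim=1)
    return tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=flat_logits)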


# `test_flatten_dims` is a `tf.test.TestCase` method; the enclosing class is
# assumed here so the test is self-contained.
class TensorUtilsTest(tf.test.TestCase):

    def test_flatten_dims(self):
        tensor = tf.reshape(tf.range(2 * 3 * 4 * 5), [2, 3, 4, 5])

        result1 = tensor_utils.flatten_dims(tensor, last_dim=1)
        self.assertAllEqual([6, 4, 5], result1.shape)
        self.assertAllEqual(tf.range(2 * 3 * 4 * 5), tf.reshape(result1, [-1]))

        self.assertAllEqual([2, 3, 20],
                            tensor_utils.flatten_dims(tensor,
                                                      first_dim=-2).shape)

        self.assertAllEqual([2, 12, 5],
                            tensor_utils.flatten_dims(tensor,
                                                      first_dim=1,
                                                      last_dim=-2).shape)

        self.assertAllEqual([24, 5],
                            tensor_utils.flatten_dims(tensor,
                                                      last_dim=-2).shape)

        self.assertAllEqual([2 * 3 * 4 * 5],
                            tensor_utils.flatten_dims(tensor).shape)

        self.assertAllEqual([2, 3, 4, 5],
                            tensor_utils.flatten_dims(tensor,
                                                      first_dim=1,
                                                      last_dim=-3).shape)

        self.assertAllEqual([7], tensor_utils.flatten_dims(tf.ones([7])).shape)

        self.assertAllEqual([12],
                            tensor_utils.flatten_dims(tf.ones([4, 3])).shape)

        with self.assertRaises(ValueError):
            tensor_utils.flatten_dims(tensor, first_dim=4)

        with self.assertRaises(ValueError):
            tensor_utils.flatten_dims(tensor, last_dim=-5)

        with self.assertRaises(ValueError):
            tensor_utils.flatten_dims(tensor, first_dim=2, last_dim=1)
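

# The assertions above pin down the `flatten_dims` contract: dimensions
# `first_dim` through `last_dim` (inclusive, negative indices allowed) collapse
# into a single dimension, and out-of-range or inverted bounds raise
# ValueError. `_flatten_dims_sketch` is a minimal reference consistent with
# those assertions (assuming fully static shapes), not the library's own
# implementation.
def _flatten_dims_sketch(tensor: tf.Tensor,
                         first_dim: int = 0,
                         last_dim: int = -1) -> tf.Tensor:
    """Collapses dimensions [first_dim, last_dim] of `tensor` into one."""
    rank = tensor.shape.rank
    # Normalize negative indices, then validate the resulting bounds.
    first = first_dim + rank if first_dim < 0 else first_dim
    last = last_dim + rank if last_dim < 0 else last_dim
    if not 0 <= first < rank:
        raise ValueError('`first_dim` out of range: {}'.format(first_dim))
    if not 0 <= last < rank:
        raise ValueError('`last_dim` out of range: {}'.format(last_dim))
    if first > last:
        raise ValueError('`first_dim` must not come after `last_dim`.')
    static_shape = tensor.shape.as_list()
    # `-1` lets `tf.reshape` infer the size of the collapsed dimension.
    new_shape = static_shape[:first] + [-1] + static_shape[(last + 1):]
    return tf.reshape(tensor, new_shape)

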
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
    """Makes local attention mask preventing attention across different segments.

  Restricts local self-attention to attend within segments, such that tokens can
  only attend to local tokens from the same segment id. The tokens in a segment
  do not need to be contiguous, but attention is still constrained by
  `local_radius`. The output can be used as `l2l_att_mask` in
  `layers.GlobalLocalTransformerLayers` for example.

  Args:
    segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
      which must be non-negative.
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.
  """
    with tf.name_scope(name or 'make_local_segmented_att_mask'):
        segment_ids = tf.convert_to_tensor(segment_ids)

        if segment_ids.shape.rank != 2:
            raise ValueError('`segment_ids` must be a 2-D tensor.')

        batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

        # Add 1 so that segment id `0` doesn't coincide with `0` padding values
        # introduced later by `tensor_utils.concat_3_blocks()` for example.
        segment_ids += 1

        # [batch_size, num_blocks, local_radius]
        blocked_segment_ids = tensor_utils.split_into_blocks(
            segment_ids, block_len=local_radius, axis=1)

        # [batch_size, num_blocks, 3*local_radius]
        concat_blocked_segment_ids = tensor_utils.concat_3_blocks(
            blocked_segment_ids)

        # [batch_size, num_blocks, local_radius, 3*local_radius]
        tiled_segment_ids = tf.tile(
            concat_blocked_segment_ids[:, :, tf.newaxis, :],
            [1, 1, local_radius, 1])

        # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
        blocked_unskewed_segment_ids = tensor_utils.unskew_elements_right(
            tiled_segment_ids, axis=-1)

        # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
        flat_unskewed_segment_ids = tensor_utils.flatten_dims(
            blocked_unskewed_segment_ids, first_dim=1, last_dim=2)

        # [batch_size, seq_len, 2*local_radius + 1]
        unskewed_segment_ids = tf.slice(flat_unskewed_segment_ids,
                                        begin=[0, 0, 0],
                                        size=[-1, seq_len, -1])

        # [batch_size, seq_len, 1]
        center_token_segment_id = unskewed_segment_ids[:, :, local_radius:(
            local_radius + 1)]

        # [batch_size, seq_len, 2*local_radius + 1]
        result = tf.cast(
            tf.equal(unskewed_segment_ids, center_token_segment_id), tf.int32)

        # Use `reshape` to set the static shape when known.
        return tf.reshape(result, [batch_size, seq_len, 2 * local_radius + 1])
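

# Illustrative companion to `make_local_segmented_att_mask` (not part of the
# original source): `_naive_local_segmented_att_mask` recomputes the mask with
# plain Python loops to spell out the semantics. Entry [b, i, k] covers the
# token at offset k - local_radius from token i and is 1 only when that
# position is in bounds and shares token i's segment id. For example,
# `_naive_local_segmented_att_mask([[1, 1, 2, 2]], local_radius=1)` should
# agree with `make_local_segmented_att_mask(tf.constant([[1, 1, 2, 2]]), 1)`.
def _naive_local_segmented_att_mask(segment_ids, local_radius):
    """Python-loop reference; `segment_ids` is a [batch, seq_len] nested list."""
    window = 2 * local_radius + 1
    mask = []
    for row in segment_ids:
        seq_len = len(row)
        row_mask = []
        for i in range(seq_len):
            entries = []
            for k in range(window):
                j = i + k - local_radius  # Absolute index of the attended token.
                entries.append(
                    1 if 0 <= j < seq_len and row[j] == row[i] else 0)
            row_mask.append(entries)
        mask.append(row_mask)
    return tf.constant(mask, dtype=tf.int32)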