def test_skew_elements_right_unskew_elements_right_round_trip(self):
        """Unskewing a right-skewed tensor recovers the original tensor."""
        # A local generator keeps the test deterministic without reseeding the
        # process-wide NumPy RNG, which other tests in the same process share.
        # RandomState(1234) yields the same stream as np.random.seed(1234).
        rng = np.random.RandomState(1234)

        tensor = tf.constant(rng.normal(size=[3, 5, 7, 11]))

        for axis in (1, 2, -1):
            with self.subTest(axis=axis):
                skewed = tensor_utils.skew_elements_right(tensor, axis)
                self.assertAllEqual(
                    tensor, tensor_utils.unskew_elements_right(skewed, axis))
    def make_relative_att_ids(self,
                              seq_len: Union[int, tf.Tensor],
                              batch_size: Optional[Union[int, tf.Tensor]] = 1,
                              name: Optional[Text] = None) -> tf.Tensor:
        """Makes relative position ids for full self-attention.

        For example, if `max_distance` is 3, `ignore_direction` is False,
        `seq_len` is 6, and `batch_size` is 1, the result is the following:
          [[
              [0, 1, 2, 3, 3, 3],
              [4, 0, 1, 2, 3, 3],
              [5, 4, 0, 1, 2, 3],
              [6, 5, 4, 0, 1, 2],
              [6, 6, 5, 4, 0, 1],
              [6, 6, 6, 5, 4, 0],
          ]]

        Args:
          seq_len: The sequence length to create ids for. Must be positive. If a
            Tensor, must be a scalar int.
          batch_size: The batch size of the result (default 1). `None` is
            treated as 1. Must be positive. If a Tensor, must be a scalar int.
            All examples in the batch will have the same id pattern.
          name: A name for the operation (optional).

        Returns:
          <int32>[batch_size, seq_len, seq_len] Tensor of relative position ids.

        Raises:
          ValueError: If `seq_len` or `batch_size` is a Python int less than 1.
        """
        if batch_size is None:
            # The annotation admits None; use the documented default instead of
            # failing later when tf.tile tries to convert None to a tensor.
            batch_size = 1

        with tf.name_scope(name or 'make_relative_att_ids'):
            # Only static Python ints can be validated eagerly here; Tensor
            # arguments are passed through unchecked.
            if isinstance(seq_len, int) and seq_len < 1:
                raise ValueError('`seq_len` must be positive.')
            if isinstance(batch_size, int) and batch_size < 1:
                raise ValueError('`batch_size` must be positive.')

            # We need the id_pattern to cover all tokens to the left of the last
            # token and all tokens to the right of the first token at the same
            # time.
            window_size = 2 * seq_len - 1

            # [window_size]
            id_pattern = self._make_relative_id_pattern(window_size)

            # [seq_len, window_size]
            id_tensor = tf.tile(id_pattern[tf.newaxis, :], [seq_len, 1])

            # [seq_len, window_size + seq_len - 1]
            # Row r is shifted right by r positions.
            id_tensor = tensor_utils.skew_elements_right(id_tensor, -1)

            # [seq_len, seq_len]
            # Starting at column seq_len - 1 keeps, for every row, the ids whose
            # relative offset lies within the sequence.
            id_tensor = tf.slice(id_tensor, [0, seq_len - 1],
                                 [seq_len, seq_len])

            return tf.tile(id_tensor[tf.newaxis, :, :], [batch_size, 1, 1])
    def test_skew_elements_right_4d(self):
        # Input of shape [2, 3, 2, 2]. The second batch entry mirrors the first
        # one scaled by 0.1, so the same skew pattern appears in both entries.
        source = tf.constant([
            [
                [[1, -1], [2, -2]],  #
                [[3, -3], [4, -4]],  #
                [[5, -5], [6, -6]],  #
            ],  #
            [
                [[.1, -.1], [.2, -.2]],  #
                [[.3, -.3], [.4, -.4]],  #
                [[.5, -.5], [.6, -.6]],  #
            ],  #
        ])

        # Skewing along `axis` shifts each slice right by its index along the
        # preceding axis, growing `axis` by (size of preceding axis - 1).
        skew_last = [
            [
                [[1, -1, 0], [0, 2, -2]],  #
                [[3, -3, 0], [0, 4, -4]],  #
                [[5, -5, 0], [0, 6, -6]],  #
            ],  #
            [
                [[.1, -.1, 0], [0, .2, -.2]],  #
                [[.3, -.3, 0], [0, .4, -.4]],  #
                [[.5, -.5, 0], [0, .6, -.6]],  #
            ],  #
        ]
        skew_second_to_last = [
            [
                [[1, -1], [2, -2], [0, 0], [0, 0]],  #
                [[0, 0], [3, -3], [4, -4], [0, 0]],  #
                [[0, 0], [0, 0], [5, -5], [6, -6]],  #
            ],  #
            [
                [[.1, -.1], [.2, -.2], [0, 0], [0, 0]],  #
                [[0, 0], [.3, -.3], [.4, -.4], [0, 0]],  #
                [[0, 0], [0, 0], [.5, -.5], [.6, -.6]],  #
            ],  #
        ]
        skew_axis_one = [
            [
                [[1, -1], [2, -2]],  #
                [[3, -3], [4, -4]],  #
                [[5, -5], [6, -6]],  #
                [[0, 0], [0, 0]],  #
            ],  #
            [
                [[0, 0], [0, 0]],  #
                [[.1, -.1], [.2, -.2]],  #
                [[.3, -.3], [.4, -.4]],  #
                [[.5, -.5], [.6, -.6]],  #
            ],  #
        ]

        for axis, expected in [(-1, skew_last), (-2, skew_second_to_last),
                               (1, skew_axis_one)]:
            actual = tensor_utils.skew_elements_right(source, axis)
            self.assertAllClose(expected, actual)
    def test_skew_elements_right_2d(self):
        # Row r of each result is the input row shifted right by r slots; the
        # vacated slots hold `pad_value` (0 unless explicitly overridden).
        square = tf.constant([
            [1, 2, 3],  #
            [4, 5, 6],  #
            [7, 8, 9],  #
        ])

        # Axis 0 has no preceding axis to skew against, so it is rejected.
        with self.assertRaises(ValueError):
            tensor_utils.skew_elements_right(square, 0)

        # Each case: (input, axis, extra keyword arguments, expected output).
        cases = [
            (square, -1, {}, [
                [1, 2, 3, 0, 0],  #
                [0, 4, 5, 6, 0],  #
                [0, 0, 7, 8, 9],  #
            ]),
            (square, 1, {'pad_value': -2}, [
                [1, 2, 3, -2, -2],  #
                [-2, 4, 5, 6, -2],  #
                [-2, -2, 7, 8, 9],  #
            ]),
            # Tall input given as a plain nested list.
            ([
                [1, 2],  #
                [3, 4],  #
                [5, 6],  #
                [7, 8],  #
                [9, 10],  #
            ], 1, {}, [
                [1, 2, 0, 0, 0, 0],  #
                [0, 3, 4, 0, 0, 0],  #
                [0, 0, 5, 6, 0, 0],  #
                [0, 0, 0, 7, 8, 0],  #
                [0, 0, 0, 0, 9, 10],  #
            ]),
            # Wide input given as a plain nested list.
            ([
                [1, 2, 3, 4],  #
                [5, 6, 7, 8],  #
            ], 1, {}, [
                [1, 2, 3, 4, 0],  #
                [0, 5, 6, 7, 8],  #
            ]),
            # A single row is never shifted, so it comes back unchanged.
            ([
                [1, 2, 3],  #
            ], 1, {}, [
                [1, 2, 3],  #
            ]),
        ]

        for data, axis, kwargs, expected in cases:
            self.assertAllEqual(
                expected,
                tensor_utils.skew_elements_right(data, axis, **kwargs))