Ejemplo n.º 1
0
    def test_skew_elements_right_unskew_elements_right_round_trip(self):
        """Skewing then unskewing along the same axis must be an identity."""
        # Fixed seed so the random input is reproducible between runs.
        np.random.seed(1234)
        original = tf.constant(np.random.normal(size=[3, 5, 7, 11]))

        # Cover inner axes given both as positive and as negative indices.
        for skew_axis in (1, 2, -1):
            skewed = tensor_utils.skew_elements_right(original, skew_axis)
            restored = tensor_utils.unskew_elements_right(skewed, skew_axis)
            self.assertAllEqual(original, restored)
Ejemplo n.º 2
0
    def test_unskew_elements_right_2d(self):
        """Exercises `unskew_elements_right` on rank-2 inputs of several shapes."""
        square_skewed = tf.constant([
            [1, 2, 3, 0, 0],  #
            [0, 4, 5, 6, 0],  #
            [0, 0, 7, 8, 9],  #
        ])
        square_unskewed = [
            [1, 2, 3],  #
            [4, 5, 6],  #
            [7, 8, 9],  #
        ]
        self.assertAllEqual(
            square_unskewed,
            tensor_utils.unskew_elements_right(square_skewed, -1))

        # Axis 0 must be rejected for this operation.
        with self.assertRaises(ValueError):
            tensor_utils.unskew_elements_right(square_skewed, 0)

        # (skewed input, expected unskewed output) pairs, all unskewed along
        # axis 1 and passed as plain nested lists instead of tensors.
        axis_1_cases = [
            (
                [
                    [1, 2, 0, 0, 0, 0],  #
                    [0, 3, 4, 0, 0, 0],  #
                    [0, 0, 5, 6, 0, 0],  #
                    [0, 0, 0, 7, 8, 0],  #
                    [0, 0, 0, 0, 9, 10],  #
                ],
                [
                    [1, 2],  #
                    [3, 4],  #
                    [5, 6],  #
                    [7, 8],  #
                    [9, 10],  #
                ],
            ),
            (
                [
                    [1, 2, 3, 4, 0],  #
                    [0, 5, 6, 7, 8],  #
                ],
                [
                    [1, 2, 3, 4],  #
                    [5, 6, 7, 8],  #
                ],
            ),
            (
                [
                    [1, 2, 3],  #
                ],
                [
                    [1, 2, 3],  #
                ],
            ),
            (
                [
                    [1, 0, 0],  #
                    [0, 2, 0],  #
                    [0, 0, 3],  #
                ],
                [
                    [1],  #
                    [2],  #
                    [3],  #
                ],
            ),
        ]
        for skewed, expected in axis_1_cases:
            self.assertAllEqual(
                expected, tensor_utils.unskew_elements_right(skewed, 1))

        # Presumably invalid because three rows skewed into a width of only 2
        # leave no room for even one unskewed column.
        with self.assertRaises(ValueError):
            tensor_utils.unskew_elements_right(
                [
                    [1, 2],  #
                    [3, 4],  #
                    [5, 6],  #
                ],
                1)
Ejemplo n.º 3
0
    def test_unskew_elements_right_4d(self):
        """Unskews one rank-4 tensor that was skewed along different axes."""
        # shape: [2, 3, 2, 2]
        want = tf.constant([
            [
                [[1, -1], [2, -2]],  #
                [[3, -3], [4, -4]],  #
                [[5, -5], [6, -6]],  #
            ],  #
            [
                [[.1, -.1], [.2, -.2]],  #
                [[.3, -.3], [.4, -.4]],  #
                [[.5, -.5], [.6, -.6]],  #
            ],  #
        ])

        # Skewed along the last axis; shape: [2, 3, 2, 3].
        skewed_on_last = [
            [
                [[1, -1, 0], [0, 2, -2]],  #
                [[3, -3, 0], [0, 4, -4]],  #
                [[5, -5, 0], [0, 6, -6]],  #
            ],  #
            [
                [[.1, -.1, 0], [0, .2, -.2]],  #
                [[.3, -.3, 0], [0, .4, -.4]],  #
                [[.5, -.5, 0], [0, .6, -.6]],  #
            ],  #
        ]

        # Skewed along the second-to-last axis; shape: [2, 3, 4, 2].
        skewed_on_second_last = [
            [
                [[1, -1], [2, -2], [0, 0], [0, 0]],  #
                [[0, 0], [3, -3], [4, -4], [0, 0]],  #
                [[0, 0], [0, 0], [5, -5], [6, -6]],  #
            ],  #
            [
                [[.1, -.1], [.2, -.2], [0, 0], [0, 0]],  #
                [[0, 0], [.3, -.3], [.4, -.4], [0, 0]],  #
                [[0, 0], [0, 0], [.5, -.5], [.6, -.6]],  #
            ],  #
        ]

        # Skewed along axis 1; shape: [2, 4, 2, 2].
        skewed_on_axis_1 = [
            [
                [[1, -1], [2, -2]],  #
                [[3, -3], [4, -4]],  #
                [[5, -5], [6, -6]],  #
                [[0, 0], [0, 0]],  #
            ],  #
            [
                [[0, 0], [0, 0]],  #
                [[.1, -.1], [.2, -.2]],  #
                [[.3, -.3], [.4, -.4]],  #
                [[.5, -.5], [.6, -.6]],  #
            ],  #
        ]

        cases = (
            (-1, skewed_on_last),
            (-2, skewed_on_second_last),
            (1, skewed_on_axis_1),
        )
        for axis, skewed in cases:
            self.assertAllClose(
                want, tensor_utils.unskew_elements_right(skewed, axis))
Ejemplo n.º 4
0
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
    """Makes local attention mask preventing attention across different segments.

    Restricts local self-attention to attend within segments, such that tokens
    can only attend to local tokens from the same segment id. The tokens in a
    segment do not need to be contiguous, but attention is still constrained by
    `local_radius`. The output can be used as `l2l_att_mask` in
    `layers.GlobalLocalTransformerLayers` for example.

    Args:
      segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
        which must be non-negative.
      local_radius: The local radius as expected by
        `layers.GlobalLocalTransformerLayers`. Must be positive.
      name: A name for the operation (optional).

    Returns:
      <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask. The
      middle entry (index `local_radius`) of each window corresponds to the
      token itself; an entry is 1 where the window element's segment id equals
      that token's segment id and 0 otherwise.

    Raises:
      ValueError: If `segment_ids` is not rank 2, or `local_radius` is not
        positive.
    """
    with tf.name_scope(name or 'make_local_segmented_att_mask'):
        # Enforce the documented precondition explicitly; a non-positive radius
        # would otherwise only fail, if at all, deeper inside the
        # `tensor_utils` helpers below with a less helpful error.
        if local_radius < 1:
            raise ValueError(
                '`local_radius` must be positive, got {}.'.format(local_radius))

        segment_ids = tf.convert_to_tensor(segment_ids)

        if segment_ids.shape.rank != 2:
            raise ValueError('`segment_ids` must be a 2-D tensor.')

        batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

        # Add 1 so that segment id `0` doesn't coincide with `0` padding values
        # introduced later by `tensor_utils.concat_3_blocks()` for example.
        segment_ids += 1

        # [batch_size, num_blocks, local_radius]
        blocked_segment_ids = tensor_utils.split_into_blocks(
            segment_ids, block_len=local_radius, axis=1)

        # [batch_size, num_blocks, 3*local_radius]
        # Presumably the previous, current and next blocks side by side, so
        # every token's full window lies inside its block's row — confirm
        # against `concat_3_blocks`.
        concat_blocked_segment_ids = tensor_utils.concat_3_blocks(
            blocked_segment_ids)

        # [batch_size, num_blocks, local_radius, 3*local_radius]
        # One copy of the 3-block row per token in the block, so the unskew
        # below can shift each copy by that token's offset within the block.
        tiled_segment_ids = tf.tile(
            concat_blocked_segment_ids[:, :, tf.newaxis, :],
            [1, 1, local_radius, 1])

        # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
        blocked_unskewed_segment_ids = tensor_utils.unskew_elements_right(
            tiled_segment_ids, axis=-1)

        # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
        flat_unskewed_segment_ids = tensor_utils.flatten_dims(
            blocked_unskewed_segment_ids, first_dim=1, last_dim=2)

        # [batch_size, seq_len, 2*local_radius + 1]
        # Drop any extra rows produced by rounding seq_len up to whole blocks.
        unskewed_segment_ids = tf.slice(flat_unskewed_segment_ids,
                                        begin=[0, 0, 0],
                                        size=[-1, seq_len, -1])

        # [batch_size, seq_len, 1]
        # The window's center column holds each token's own segment id.
        center_token_segment_id = unskewed_segment_ids[:, :, local_radius:(
            local_radius + 1)]

        # [batch_size, seq_len, 2*local_radius + 1]
        result = tf.cast(
            tf.equal(unskewed_segment_ids, center_token_segment_id), tf.int32)

        # Use `reshape` to set the static shape when known.
        return tf.reshape(result, [batch_size, seq_len, 2 * local_radius + 1])