def test_skew_elements_right_unskew_elements_right_round_trip(self):
    """Checks that unskewing a skewed tensor along the same axis restores it."""
    np.random.seed(1234)
    original = tf.constant(np.random.normal(size=[3, 5, 7, 11]))

    # Cover two positive axes and one negative axis of the rank-4 input.
    for axis in (1, 2, -1):
        skewed = tensor_utils.skew_elements_right(original, axis)
        restored = tensor_utils.unskew_elements_right(skewed, axis)
        self.assertAllEqual(original, restored)
def make_relative_att_ids(self,
                          seq_len: Union[int, tf.Tensor],
                          batch_size: Optional[Union[int, tf.Tensor]] = 1,
                          name: Optional[Text] = None) -> tf.Tensor:
    """Builds relative position ids for full self-attention.

    As an illustration, with `max_distance` equal to 3, `ignore_direction`
    set to False, `seq_len` equal to 6 and `batch_size` equal to 1, the
    returned ids are:
      [[
          [0, 1, 2, 3, 3, 3],
          [4, 0, 1, 2, 3, 3],
          [5, 4, 0, 1, 2, 3],
          [6, 5, 4, 0, 1, 2],
          [6, 6, 5, 4, 0, 1],
          [6, 6, 6, 5, 4, 0],
      ]]

    Args:
      seq_len: Length of the sequence to generate ids for. Must be positive.
        A Tensor value must be a scalar int.
      batch_size: Leading dimension of the result (default 1). Must be
        positive. A Tensor value must be a scalar int. The same id pattern
        is repeated for every example in the batch.
      name: Optional name for the operation.

    Returns:
      <int32>[batch_size, seq_len, seq_len] Tensor of relative position ids.

    Raises:
      ValueError: If `seq_len` or `batch_size` is a Python int below 1.
    """
    scope_name = name or 'make_relative_att_ids'
    with tf.name_scope(scope_name):
        # Only Python ints are validated here; Tensor arguments skip both
        # checks because of the `isinstance` guard.
        seq_len_invalid = isinstance(seq_len, int) and seq_len < 1
        if seq_len_invalid:
            raise ValueError('`seq_len` must be positive.')
        batch_size_invalid = isinstance(batch_size, int) and batch_size < 1
        if batch_size_invalid:
            raise ValueError('`batch_size` must be positive.')

        # The pattern has to reach every token left of the last token and
        # every token right of the first token simultaneously.
        pattern_width = 2 * seq_len - 1

        # [pattern_width]
        base_pattern = self._make_relative_id_pattern(pattern_width)

        # [seq_len, pattern_width]
        repeated_rows = tf.tile(base_pattern[tf.newaxis, :], [seq_len, 1])

        # [seq_len, pattern_width + seq_len - 1]
        skewed_rows = tensor_utils.skew_elements_right(repeated_rows, -1)

        # [seq_len, seq_len]
        slice_begin = [0, seq_len - 1]
        slice_size = [seq_len, seq_len]
        square_ids = tf.slice(skewed_rows, slice_begin, slice_size)

        # Add the batch axis and repeat the same matrix for each example.
        return tf.tile(square_ids[tf.newaxis, :, :], [batch_size, 1, 1])
def test_skew_elements_right_4d(self):
    """Checks skewing a [2, 3, 2, 2] tensor along axes -1, -2 and 1."""
    # shape: [2, 3, 2, 2]
    inputs = tf.constant([
        [
            [[1, -1], [2, -2]],
            [[3, -3], [4, -4]],
            [[5, -5], [6, -6]],
        ],
        [
            [[0.1, -0.1], [0.2, -0.2]],
            [[0.3, -0.3], [0.4, -0.4]],
            [[0.5, -0.5], [0.6, -0.6]],
        ],
    ])

    # Skewing the last axis: result shape [2, 3, 2, 3].
    expected_last_axis = [
        [
            [[1, -1, 0], [0, 2, -2]],
            [[3, -3, 0], [0, 4, -4]],
            [[5, -5, 0], [0, 6, -6]],
        ],
        [
            [[0.1, -0.1, 0], [0, 0.2, -0.2]],
            [[0.3, -0.3, 0], [0, 0.4, -0.4]],
            [[0.5, -0.5, 0], [0, 0.6, -0.6]],
        ],
    ]
    actual_last_axis = tensor_utils.skew_elements_right(inputs, -1)
    self.assertAllClose(expected_last_axis, actual_last_axis)

    # Skewing the second-to-last axis: result shape [2, 3, 4, 2].
    expected_penultimate_axis = [
        [
            [[1, -1], [2, -2], [0, 0], [0, 0]],
            [[0, 0], [3, -3], [4, -4], [0, 0]],
            [[0, 0], [0, 0], [5, -5], [6, -6]],
        ],
        [
            [[0.1, -0.1], [0.2, -0.2], [0, 0], [0, 0]],
            [[0, 0], [0.3, -0.3], [0.4, -0.4], [0, 0]],
            [[0, 0], [0, 0], [0.5, -0.5], [0.6, -0.6]],
        ],
    ]
    actual_penultimate_axis = tensor_utils.skew_elements_right(inputs, -2)
    self.assertAllClose(expected_penultimate_axis, actual_penultimate_axis)

    # Skewing axis 1: result shape [2, 4, 2, 2].
    expected_axis_one = [
        [
            [[1, -1], [2, -2]],
            [[3, -3], [4, -4]],
            [[5, -5], [6, -6]],
            [[0, 0], [0, 0]],
        ],
        [
            [[0, 0], [0, 0]],
            [[0.1, -0.1], [0.2, -0.2]],
            [[0.3, -0.3], [0.4, -0.4]],
            [[0.5, -0.5], [0.6, -0.6]],
        ],
    ]
    actual_axis_one = tensor_utils.skew_elements_right(inputs, 1)
    self.assertAllClose(expected_axis_one, actual_axis_one)
def test_skew_elements_right_2d(self):
    """Checks skewing of 2-D inputs, custom padding and the axis-0 error."""
    square = tf.constant([
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ])

    # Default padding, with the last axis addressed as -1.
    expected_zero_padded = [
        [1, 2, 3, 0, 0],
        [0, 4, 5, 6, 0],
        [0, 0, 7, 8, 9],
    ]
    self.assertAllEqual(expected_zero_padded,
                        tensor_utils.skew_elements_right(square, -1))

    # Same skew addressed as axis 1, with an explicit pad value.
    expected_custom_padded = [
        [1, 2, 3, -2, -2],
        [-2, 4, 5, 6, -2],
        [-2, -2, 7, 8, 9],
    ]
    self.assertAllEqual(
        expected_custom_padded,
        tensor_utils.skew_elements_right(square, 1, pad_value=-2))

    # Axis 0 is rejected.
    self.assertRaises(ValueError, tensor_utils.skew_elements_right, square, 0)

    # Non-square inputs given as plain nested lists: (input, expected) pairs
    # for a tall matrix, a wide matrix and a single row.
    list_cases = [
        (
            [
                [1, 2],
                [3, 4],
                [5, 6],
                [7, 8],
                [9, 10],
            ],
            [
                [1, 2, 0, 0, 0, 0],
                [0, 3, 4, 0, 0, 0],
                [0, 0, 5, 6, 0, 0],
                [0, 0, 0, 7, 8, 0],
                [0, 0, 0, 0, 9, 10],
            ],
        ),
        (
            [
                [1, 2, 3, 4],
                [5, 6, 7, 8],
            ],
            [
                [1, 2, 3, 4, 0],
                [0, 5, 6, 7, 8],
            ],
        ),
        (
            [
                [1, 2, 3],
            ],
            [
                [1, 2, 3],
            ],
        ),
    ]
    for matrix, expected in list_cases:
        self.assertAllEqual(expected,
                            tensor_utils.skew_elements_right(matrix, 1))