def test_compute_mask_indices_overlap(self):
    """Check that each row's mask total stays within ``mask_prob * sequence_length``.

    Runs ``_compute_mask_indices`` with ``mask_length = 4`` and asserts an
    upper bound, not an exact value, on every per-row sum of the mask.
    """
    batch_size = 4
    sequence_length = 80
    mask_prob = 0.5
    mask_length = 4

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

    # Because of overlap the row sums don't have to add up exactly to
    # `mask_prob * sequence_length`, but they have to be smaller or equal.
    max_masked = mask_prob * sequence_length
    for batch_sum in tf.reduce_sum(mask, -1):
        # assertLessEqual reports both operands on failure, whereas
        # assertTrue(a <= b) would only say "False is not true".
        self.assertLessEqual(int(batch_sum), max_masked)
def test_compute_mask_indices(self):
    """Check that every row of the mask sums to exactly ``mask_prob * sequence_length``.

    Runs ``_compute_mask_indices`` with a mask length of 1 and compares the
    per-row totals against the expected value repeated once per batch row.
    """
    num_rows, num_steps = 4, 60
    prob, span = 0.5, 1

    mask = _compute_mask_indices((num_rows, num_steps), prob, span)

    # Sum over the last axis to get one total per batch row.
    row_totals = tf.reduce_sum(mask, -1).numpy().tolist()
    expected_totals = [prob * num_steps] * num_rows

    self.assertListEqual(row_totals, expected_totals)