def test_compute_mask_indices_low_prob(self):
    # Low-probability regime: mask_prob * sequence_length / mask_length = 0.5
    # expected spans per call. Per the original author's note, probabilistic
    # rounding should then yield 0 spans in about half of the calls and
    # 1 span in the other half.
    n_trials = 100
    batch_size = 4
    sequence_length = 100
    mask_prob = 0.05
    mask_length = 10
    shape = (batch_size, sequence_length)

    trials_with_mask = 0
    for _ in range(n_trials):
        np_mask = _compute_mask_indices(shape, mask_prob, mask_length)
        masked_total = torch.from_numpy(np_mask).to(torch_device).sum().item()
        if masked_total > 0:
            trials_with_mask += 1
    # every trial lands in exactly one of the two buckets
    trials_without_mask = n_trials - trials_with_mask

    # Both outcomes must occur in more than 10% of the trials (strict
    # comparison). The original author estimated the failure odds as
    # P(100 coin flips, at most 9 heads) = 1.66e-18.
    min_count = int(n_trials * 0.1)
    self.assertGreater(trials_with_mask, min_count)
    self.assertGreater(trials_without_mask, min_count)
def test_compute_mask_indices(self):
    # With mask_length == 1 every row is expected to contain exactly
    # mask_prob * sequence_length masked positions.
    batch_size = 4
    sequence_length = 60
    mask_prob = 0.5
    mask_length = 1
    shape = (batch_size, sequence_length)

    np_mask = _compute_mask_indices(shape, mask_prob, mask_length)
    mask = torch.from_numpy(np_mask).to(torch_device)

    per_row_counts = mask.sum(dim=-1).tolist()
    expected_counts = [mask_prob * sequence_length] * batch_size
    self.assertListEqual(per_row_counts, expected_counts)
def test_compute_mask_indices_overlap(self):
    # Checks the per-row masked count when spans are longer than one
    # position (mask_length > 1).
    batch_size = 4
    sequence_length = 80
    mask_prob = 0.5
    mask_length = 4

    mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
    mask = torch.from_numpy(mask).to(torch_device)

    # Because spans may overlap, the masked count per row does not have to
    # equal `mask_prob * sequence_length` exactly, but it must never exceed it.
    max_masked = mask_prob * sequence_length
    for batch_sum in mask.sum(axis=-1):
        # assertLessEqual reports both operands on failure, unlike
        # assertTrue(a <= b) which only says "False is not true".
        self.assertLessEqual(int(batch_sum), max_masked)
def test_compute_mask_indices_short_audio(self):
    # Example 0 keeps only 5 attended positions, fewer than mask_length,
    # while min_masks=2 is requested. The test asserts that none of those
    # attended positions ends up masked.
    batch_size = 4
    sequence_length = 100
    mask_prob = 0.05
    mask_length = 10
    shape = (batch_size, sequence_length)

    attention_mask = torch.ones(shape, dtype=torch.long, device=torch_device)
    # force one example to be heavily padded
    attention_mask[0, 5:] = 0

    mask = _compute_mask_indices(
        shape, mask_prob, mask_length, attention_mask=attention_mask, min_masks=2
    )

    # boolean selector for the non-padded positions of the short example
    attended_positions = attention_mask[0].to(torch.bool).cpu()
    masked_in_attended = mask[0][attended_positions]
    self.assertFalse(masked_in_attended.any())
def test_compute_mask_indices_attn_mask_overlap(self):
    # Combines multi-position spans (mask_length > 1) with an attention
    # mask that zeroes out the second half of the first two examples.
    batch_size = 4
    sequence_length = 80
    mask_prob = 0.5
    mask_length = 4

    attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
    # mark the second half of the first two examples as padding
    attention_mask[:2, sequence_length // 2 :] = 0

    mask = _compute_mask_indices(
        (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
    )
    mask = torch.from_numpy(mask).to(torch_device)

    # the masked count per row must never exceed `mask_prob * sequence_length`
    max_masked = mask_prob * sequence_length
    for batch_sum in mask.sum(axis=-1):
        # dedicated assert methods show the offending values on failure
        self.assertLessEqual(int(batch_sum), max_masked)

    # nothing may be masked inside the padded region of the first two examples
    self.assertEqual(mask[:2, sequence_length // 2 :].sum().item(), 0)