def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):
    """Check that mask_along_axis_iid leaves its input Tensor unmodified.

    The check is repeated 5 times to bound the probability of no masking
    occurring to 1e-10.
    See https://github.com/pytorch/audio/issues/1478
    """
    torch.random.manual_seed(42)
    shape = (4, 2, 1025, 400)
    for _ in range(5):
        original = torch.randn(*shape, dtype=self.dtype, device=self.device)
        # Snapshot taken before the call so any in-place write is detectable.
        snapshot = original.clone()
        # The returned tensor is deliberately ignored; only the input matters.
        F.mask_along_axis_iid(original, mask_param, mask_value, axis)
        self.assertEqual(original, snapshot)
def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
    """Check the output size and mask width of F.mask_along_axis_iid.

    Two properties are asserted on the masked result:

    * it has the same size as ``specgrams``;
    * for every leading index, the number of lines along ``axis`` that are
      entirely equal to ``mask_value`` is strictly below ``mask_param``.
    """
    masked = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

    # The axis reduced over is the one that is not masked; any ``axis``
    # other than 3 pairs with 3.
    if axis == 3:
        reduce_axis = 2
    else:
        reduce_axis = 3

    # A line counts as masked only when every entry along ``reduce_axis``
    # equals ``mask_value``.
    hits_per_line = (masked == mask_value).sum(reduce_axis)
    line_is_masked = hits_per_line == masked.size(reduce_axis)
    masked_line_counts = line_is_masked.sum(-1)

    assert masked.size() == specgrams.size()
    within_limit = masked_line_counts < mask_param
    assert within_limit.sum() == masked_line_counts.numel()
def test_mask_along_axis_iid(mask_param, mask_value, axis):
    """Check the output size and mask width of F.mask_along_axis_iid.

    A seeded random tensor of size (4, 2, 1025, 400) is masked, then two
    properties are asserted on the result:

    * it has the same size as the input;
    * for every index of the two leading dimensions, the number of lines
      along ``axis`` that are entirely equal to ``mask_value`` is strictly
      below ``mask_param``.
    """
    torch.random.manual_seed(42)
    source = torch.randn(4, 2, 1025, 400)
    masked = F.mask_along_axis_iid(source, mask_param, mask_value, axis)

    # The axis reduced over is the one that is not masked; any ``axis``
    # other than 3 pairs with 3.
    if axis == 3:
        reduce_axis = 2
    else:
        reduce_axis = 3

    # A line counts as masked only when every entry along ``reduce_axis``
    # equals ``mask_value``.
    hits_per_line = (masked == mask_value).sum(reduce_axis)
    line_is_masked = hits_per_line == masked.size(reduce_axis)
    masked_line_counts = line_is_masked.sum(-1)

    assert masked.size() == source.size()
    within_limit = masked_line_counts < mask_param
    assert within_limit.sum() == masked_line_counts.numel()
def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
    r"""Apply masking along the configured axis of a spectrogram.

    Args:
        specgram (Tensor): Tensor of dimension (..., freq, time).
        mask_value (float): Value to assign to the masked columns.

    Returns:
        Tensor: Masked spectrogram of dimensions (..., freq, time).
    """
    # Without the iid_masks flag, or when the input is not 4-dimensional
    # (i.e. has no batch dimension), fall back to the plain masking routine.
    if not self.iid_masks or specgram.dim() != 4:
        return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)

    # The iid variant is called with the axis index shifted by one.
    # NOTE(review): presumably the two functions count axes differently
    # for a 4-D input -- confirm against F.mask_along_axis_iid.
    return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
def func(tensor):
    """Call F.mask_along_axis_iid on ``tensor`` with fixed arguments.

    The mask parameter is 100, the fill value is 30.0 and the axis is 2.
    """
    param = 100
    fill = 30.
    dim = 2
    masked = F.mask_along_axis_iid(tensor, param, fill, dim)
    return masked