def test_batch_mulaw(self):
    """Check MuLawEncoding/MuLawDecoding agree between batched and per-item input.

    A single waveform is transformed and then replicated into a batch of 3;
    the result must match transforming the already-replicated batch.
    """
    waveform, _ = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

    # Reference: transform the single waveform, then replicate into a batch of 3.
    waveform_encoded = transforms.MuLawEncoding()(waveform)
    expected_encoded = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

    # Batch first, then transform.
    waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
    computed_encoded = transforms.MuLawEncoding()(waveform_batched)

    # Mu-law encoding does not change the shape: (3, channels, time).
    self.assertEqual(computed_encoded.shape, expected_encoded.shape)
    # Encoded values are integer quantization codes, so compare exactly.
    self.assertTrue(torch.equal(computed_encoded, expected_encoded))

    # Reference: decode the single encoded waveform, then replicate into a batch.
    waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
    expected_decoded = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

    # Decode the batched encoding directly.
    computed_decoded = transforms.MuLawDecoding()(computed_encoded)

    # Decoding also preserves the (3, channels, time) shape.
    self.assertEqual(computed_decoded.shape, expected_decoded.shape)
    self.assertTrue(torch.allclose(computed_decoded, expected_decoded))
def test_mu_law_companding(self):
    """Round-trip ``self.waveform`` through mu-law encoding and decoding.

    Normalizes the waveform to [-1, 1], encodes it to integer codes in
    [0, quantization_channels - 1], decodes it back, and checks the value
    range at every stage.
    """
    quantization_channels = 256

    waveform = self.waveform.clone()
    # In-place true division below raises on integer tensors, so promote first.
    if not waveform.is_floating_point():
        waveform = waveform.to(torch.get_default_dtype())
    peak = torch.abs(waveform).max()
    # A silent (all-zero) waveform would normalize to NaNs and fail obscurely.
    self.assertGreater(peak.item(), 0, "waveform must not be silent")
    waveform /= peak
    self.assertGreaterEqual(waveform.min().item(), -1.)
    self.assertLessEqual(waveform.max().item(), 1.)

    waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
    # Codes range over 0 .. quantization_channels - 1 (mu = channels - 1).
    self.assertGreaterEqual(waveform_mu.min().item(), 0)
    self.assertLess(waveform_mu.max().item(), quantization_channels)

    waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
    self.assertGreaterEqual(waveform_exp.min().item(), -1.)
    self.assertLessEqual(waveform_exp.max().item(), 1.)
def test_mu_law_companding(self):
    """Round-trip ``self.waveform`` through mu-law encoding and decoding.

    Normalizes the waveform to [-1, 1], encodes it to integer codes in
    [0, quantization_channels - 1], decodes it back, and checks the value
    range at every stage.
    """
    quantization_channels = 256

    waveform = self.waveform.clone()
    # In-place true division below raises on integer tensors, so promote first.
    if not waveform.is_floating_point():
        waveform = waveform.to(torch.get_default_dtype())
    peak = torch.abs(waveform).max()
    # A silent (all-zero) waveform would normalize to NaNs and fail obscurely.
    self.assertGreater(peak.item(), 0, "waveform must not be silent")
    waveform /= peak
    self.assertGreaterEqual(waveform.min().item(), -1.)
    self.assertLessEqual(waveform.max().item(), 1.)

    waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
    # Codes range over 0 .. quantization_channels - 1 (mu = channels - 1).
    self.assertGreaterEqual(waveform_mu.min().item(), 0)
    self.assertLess(waveform_mu.max().item(), quantization_channels)

    waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
    self.assertGreaterEqual(waveform_exp.min().item(), -1.)
    self.assertLessEqual(waveform_exp.max().item(), 1.)
def test_MuLawDecoding(self):
    """Run ``T.MuLawDecoding`` through the ``_assert_consistency`` helper.

    The input is a random float tensor of shape (1, 10).
    """
    sample = torch.rand((1, 10))
    decoder = T.MuLawDecoding()
    self._assert_consistency(decoder, sample)