Example #1
    def test_lc2cl(self):

        audio = self.sig.clone()
        result = transforms.LC2CL()(audio)
        self.assertTrue(result.size()[::-1] == audio.size())

        repr_test = transforms.LC2CL()
        self.assertTrue(repr_test.__repr__())
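LC2CL simply swaps the two axes of a (length, channels) tensor. A minimal standalone sketch of that behavior, assuming the legacy (pre-0.3) torchaudio layout the tests above use:

import torch

waveform = torch.randn(16000, 2)  # one second of stereo audio, (length, channels)
channels_first = waveform.transpose(0, 1).contiguous()  # what LC2CL returns: (channels, length)
assert channels_first.size() == waveform.size()[::-1]  # the property test_lc2cl checks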
Example #2
 def test_mel2(self):
     audio_orig = self.sig.clone()  # (16000, 1)
     audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
     audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
     spectrogram_torch = transforms.MEL2()(audio_scaled)  # (1, 319, 40)
     self.assertTrue(spectrogram_torch.dim() == 3)
     self.assertTrue(spectrogram_torch.max() <= 0.)
Example #3
def get_loader(config, data_dir):
    root = os.path.join(os.path.abspath(os.curdir), data_dir)
    print('-- Loading audios')
    dataset = AudioFolder(root=root,
                          transform=transforms.Compose([
                              transforms.PadTrim(133623, 0),
                              transforms.LC2CL()
                          ]))
    loader = DataLoader(dataset=dataset,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers)
    return loader
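For reference, a minimal sketch of what the composed transform does to a single clip before batching, assuming the legacy semantics where PadTrim(max_len, fill_value) pads or truncates the length axis and LC2CL moves channels first (the 133623-sample target comes from the call above):

import torch

clip = torch.randn(50000, 1)                # (length, channels), shorter than the target
padded = torch.full((133623, 1), 0.)        # PadTrim(133623, 0): pad with fill_value 0
padded[:clip.size(0)] = clip                # original samples first, padding after
fixed = padded.transpose(0, 1).contiguous() # LC2CL: (1, 133623), ready to stack into batches
assert fixed.size() == (1, 133623)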
Example #4
    def test_mel2(self):
        top_db = 80.
        s2db = transforms.SpectrogramToDB("power", top_db)

        audio_orig = self.sig.clone()  # (16000, 1)
        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(
            spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {
            "window": torch.hamming_window,
            "pad": 10,
            "ws": 500,
            "hop": 125,
            "n_fft": 800,
            "n_mels": 50
        }
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(
            spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
        spectrogram_stereo = s2db(mel_transform(x_stereo))
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(
            spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100,
                                                  sr=16000,
                                                  f_max=None,
                                                  f_min=0.,
                                                  n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
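The top_db assertions rely on the decibel conversion clamping its output to the range [max - top_db, max]. A standalone sketch of that clamp on a power spectrogram (this illustrates the property being tested, not torchaudio's actual implementation):

import torch

top_db = 80.
power = torch.rand(1, 319, 40) + 1e-10              # fake power spectrogram, strictly positive
db = 10. * torch.log10(power)                       # power to decibels
db = torch.clamp(db, min=db.max().item() - top_db)  # limit the dynamic range to top_db
assert db.ge(db.max() - top_db).all()               # the invariant asserted above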
Example #5
 def test_mel2(self):
     audio_orig = self.sig.clone()  # (16000, 1)
     audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
     audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
     mel_transform = transforms.MEL2()
     # check defaults
     spectrogram_torch = mel_transform(audio_scaled)  # (1, 319, 40)
     self.assertTrue(spectrogram_torch.dim() == 3)
     self.assertTrue(spectrogram_torch.le(0.).all())
     self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
     # check correctness of filterbank conversion matrix
     self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
     self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
     # check options
     mel_transform2 = transforms.MEL2(window=torch.hamming_window,
                                      pad=10,
                                      ws=500,
                                      hop=125,
                                      n_fft=800,
                                      n_mels=50)
     spectrogram2_torch = mel_transform2(audio_scaled)  # (1, 506, 50)
     self.assertTrue(spectrogram2_torch.dim() == 3)
     self.assertTrue(spectrogram2_torch.le(0.).all())
     self.assertTrue(spectrogram2_torch.ge(mel_transform2.top_db).all())
     self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
     self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
     self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
     # check on multi-channel audio
     x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
     spectrogram_stereo = mel_transform(x_stereo)
     self.assertTrue(spectrogram_stereo.dim() == 3)
     self.assertTrue(spectrogram_stereo.size(0) == 2)
     self.assertTrue(spectrogram_stereo.le(0.).all())
     self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
     # check filterbank matrix creation
     fb_matrix_transform = transforms.F2M(n_mels=100,
                                          sr=16000,
                                          f_max=None,
                                          f_min=0.,
                                          n_stft=400)
     self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
     self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
     self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
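The fb.sum(1) checks encode a property of triangular mel filterbanks: each STFT bin's total weight across all overlapping filters lies between 0 and 1. A toy filterbank with evenly spaced triangles (linear spacing rather than mel, purely to illustrate the invariant) behaves the same way:

import torch

def tri_filterbank(n_stft, n_mels):
    # evenly spaced triangular filters; edges of filter m are pts[m], pts[m+1], pts[m+2]
    pts = torch.linspace(0., n_stft - 1., n_mels + 2)
    bins = torch.arange(n_stft, dtype=torch.float)
    fb = torch.zeros(n_stft, n_mels)
    for m in range(n_mels):
        up = (bins - pts[m]) / (pts[m + 1] - pts[m])
        down = (pts[m + 2] - bins) / (pts[m + 2] - pts[m + 1])
        fb[:, m] = torch.clamp(torch.min(up, down), min=0.)
    return fb

fb = tri_filterbank(400, 100)
assert fb.size() == (400, 100)
assert fb.sum(1).le(1. + 1e-6).all() and fb.sum(1).ge(0.).all()  # small tolerance for float error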
Example #6
 def test_mel2(self):
     audio_orig = self.sig.clone()  # (16000, 1)
     audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
     audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
     mel_transform = transforms.MEL2(window=torch.hamming_window, pad=10)
     spectrogram_torch = mel_transform(audio_scaled)  # (1, 319, 40)
     self.assertTrue(spectrogram_torch.dim() == 3)
     self.assertTrue(spectrogram_torch.le(0.).all())
     self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
     # load stereo file
     x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
     spectrogram_stereo = mel_transform(x_stereo)
     self.assertTrue(spectrogram_stereo.dim() == 3)
     self.assertTrue(spectrogram_stereo.size(0) == 2)
     self.assertTrue(spectrogram_stereo.le(0.).all())
     self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
     self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
Example #7
    def test_mfcc(self):
        audio_orig = self.sig.clone()
        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[1] == 321)
        # check melkwargs are passed through
        melkwargs = {'ws': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)
        self.assertTrue(torch_mfcc2.shape[1] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)

        norm_check = torch_mfcc.clone()
        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
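The norm check exploits the standard relation between an unnormalized DCT-II and its orthonormal variant: the k = 0 coefficient differs by a factor of 2 * sqrt(N) and the remaining coefficients by sqrt(2N), which for N = n_mels gives the sqrt(n_mels) * 2 and sqrt(n_mels / 2) * 2 factors above. A quick verification against scipy's reference DCT (scipy is an assumption here; the test itself does not use it):

import numpy as np
from scipy.fftpack import dct

n = 128  # matches n_mels above
x = np.random.randn(n)
c_none = dct(x, type=2, norm=None)
c_ortho = dct(x, type=2, norm='ortho')
assert np.allclose(c_none[0], c_ortho[0] * np.sqrt(n) * 2)
assert np.allclose(c_none[1:], c_ortho[1:] * np.sqrt(n / 2) * 2)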
Example #8
    def get_dataloader(self):
        usl = self.loss_criterion == "crossentropy"  # single-label targets for cross-entropy
        ds = AUDIOSET(self.data_path, dataset=self.args.dataset, noises_dir=self.noises_dir,
                      use_cache=False, num_samples=self.args.num_samples,
                      add_no_label=self.args.add_no_label, use_single_label=usl)
        if any(x in self.model_name for x in ["resnet34_conv", "resnet101_conv", "squeezenet"]):
            T = tat.Compose([
                    #tat.PadTrim(self.max_len, fill_value=1e-8),
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.MEL(sr=16000, n_fft=600, hop_length=300, n_mels=self.args.freq_bands//2),
                    #mgc_transforms.Scale(),
                    mgc_transforms.BLC2CBL(),
                    mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
                ])
        elif "_mfcc_librosa" in self.model_name:
            T = tat.Compose([
                    #tat.PadTrim(self.max_len, fill_value=1e-8),
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.MFCC2(sr=16000, n_fft=600, hop_length=300, n_mfcc=12),
                    mgc_transforms.Scale(),
                    mgc_transforms.BLC2CBL(),
                    mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
                ])
        elif "_mfcc" in self.model_name:
            sr = 16000
            ws = 800
            hs = ws // 2
            n_fft = 512 # 256
            n_filterbanks = 26
            n_coefficients = 12
            low_mel_freq = 0
            high_freq_mel = 2595 * math.log10(1 + (sr / 2) / 700)  # Nyquist frequency in mels (HTK formula)
            mel_pts = torch.linspace(low_mel_freq, high_freq_mel, n_filterbanks + 2)  # filter edges, equally spaced in mel
            hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))  # convert edges back to Hz
            bins = torch.floor((n_fft + 1) * hz_pts / sr)  # FFT bin index of each filter edge
            td = {
                    "RfftPow": mgc_transforms.RfftPow(n_fft),
                    "FilterBanks": mgc_transforms.FilterBanks(n_filterbanks, bins),
                    "MFCC": mgc_transforms.MFCC(n_filterbanks, n_coefficients),
                 }

            T = tat.Compose([
                    #tat.PadTrim(self.max_len, fill_value=1e-8),
                    mgc_transforms.Preemphasis(),
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.Sig2Features(ws, hs, td),
                    mgc_transforms.DummyDim(),
                    mgc_transforms.Scale(),
                    tat.BLC2CBL(),
                    mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
                ])
        elif "attn" in self.model_name:
            T = tat.Compose([
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.MEL(sr=16000, n_fft=600, hop_length=300, n_mels=self.args.freq_bands//2),
                    #mgc_transforms.Scale(),
                    mgc_transforms.SqueezeDim(2),
                    tat.LC2CL(),
                ])
        elif "bytenet" in self.model_name:
            #offset = 714 # make clips divisible by 224
            T = tat.Compose([
                    mgc_transforms.SimpleTrim(self.max_len),
                    #tat.PadTrim(self.max_len),
                    mgc_transforms.Scale(),
                    tat.LC2CL(),
                ])
        ds.transform = T
        if self.loss_criterion == "crossentropy":
            TT = mgc_transforms.XEntENC(ds.labels_dict)
            #TT = mgc_transforms.BinENC(ds.labels_dict, dtype=torch.int64)
        else:
            TT = mgc_transforms.BinENC(ds.labels_dict)
        ds.target_transform = TT
        ds.use_cache = self.use_cache
        if self.use_cache:
            ds.init_cache()
        if self.use_precompute:
            ds.load_precompute(self.model_name)
        dl = data.DataLoader(ds, batch_size=self.batch_size, drop_last=True,
                             num_workers=self.num_workers, collate_fn=bce_collate,
                             shuffle=True)
        if "attn" in self.model_name:
            dl.collate_fn = sort_collate
        return ds, dl
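The hand-rolled filterbank in the "_mfcc" branch uses the HTK mel formula mel = 2595 * log10(1 + f / 700). A standalone check of the bin-edge computation with the same constants (values only, no transforms involved):

import math
import torch

sr, n_fft, n_filterbanks = 16000, 512, 26
high_freq_mel = 2595 * math.log10(1 + (sr / 2) / 700)          # ~2840.0 mel at the Nyquist frequency
mel_pts = torch.linspace(0, high_freq_mel, n_filterbanks + 2)  # equal spacing in mel
hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))
bins = torch.floor((n_fft + 1) * hz_pts / sr)                  # FFT bin edges per filter
assert (bins[1:] >= bins[:-1]).all() and bins[0] == 0          # monotone non-decreasing, starting at DC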