Example #1
0
    def get_dataloader(self):
        """Build the VOXFORGE dataset with a model-specific transform pipeline.

        Selects the input transform `T` based on `self.model_name`:
        conv ResNets get a MEL-spectrogram resized to 224x224; the MFCC
        variant gets a hand-rolled filterbank/MFCC feature chain.

        Returns:
            (vx, dl): the configured VOXFORGE dataset and its DataLoader.

        Raises:
            ValueError: if `self.model_name` matches no known pipeline
                (previously this fell through to a NameError on `T`).
        """
        vx = VOXFORGE(args.data_path,
                      langs=args.languages,
                      label_type="lang",
                      use_cache=args.use_cache,
                      use_precompute=args.use_precompute)
        if self.model_name in ("resnet34_conv", "resnet101_conv"):
            T = tat.Compose([
                tat.MEL(n_mels=224),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                tvt.Resize((224, 224)),
                tvt.ToTensor(),
            ])
            TT = spl_transforms.LENC(vx.LABELS)
        elif self.model_name == "resnet34_mfcc":
            sr = 16000
            ws = 800            # window size (samples)
            hs = ws // 2        # hop size: 50% overlap
            n_fft = 512
            n_filterbanks = 26
            n_coefficients = 12
            low_mel_freq = 0
            # Hz -> mel: m = 2595 * log10(1 + f / 700); build evenly spaced
            # mel points, then map back to Hz and to FFT bin indices.
            high_freq_mel = (2595 * math.log10(1 + (sr / 2) / 700))
            mel_pts = torch.linspace(low_mel_freq, high_freq_mel,
                                     n_filterbanks + 2)
            hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))
            bins = torch.floor((n_fft + 1) * hz_pts / sr)
            td = {
                "RfftPow": spl_transforms.RfftPow(n_fft),
                "FilterBanks": spl_transforms.FilterBanks(n_filterbanks, bins),
                "MFCC": spl_transforms.MFCC(n_filterbanks, n_coefficients),
            }

            T = tat.Compose([
                tat.Scale(),
                spl_transforms.Preemphasis(),
                spl_transforms.Sig2Features(ws, hs, td),
                spl_transforms.DummyDim(),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                tvt.Resize((224, 224)),
                tvt.ToTensor(),
            ])
            TT = spl_transforms.LENC(vx.LABELS)
        else:
            # Fail fast with a clear message instead of an unbound-local
            # NameError on `T` below.
            raise ValueError("unsupported model_name: {}".format(self.model_name))
        vx.transform = T
        vx.target_transform = TT
        if args.use_precompute:
            vx.load_precompute(args.model_name)
        dl = data.DataLoader(vx,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=True)
        return vx, dl
Example #2
0
    def test_mel(self):
        """Check tensor ranks through the Scale -> MEL -> BLC2CBL pipeline."""
        sig = transforms.Scale()(self.sig.clone())
        self.assertTrue(sig.dim() == 2)

        mel = transforms.MEL()(sig)
        self.assertTrue(mel.dim() == 3)

        mel = transforms.BLC2CBL()(mel)
        self.assertTrue(mel.dim() == 3)

        # Exercise __repr__ on both transforms for coverage.
        for tf in (transforms.MEL(), transforms.BLC2CBL()):
            tf.__repr__()
Example #3
0
    def test1(self):
        """Smoke-test: fine-tune a pretrained resnet34 on VOXFORGE language
        labels for a handful of minibatches, validating every 5th step.

        NOTE(review): uses legacy pre-0.4 PyTorch idioms (`Variable`,
        `loss.data[0]`, `tvt.Scale`) — presumably matching the torch/
        torchvision versions pinned for this project; confirm before
        modernizing.
        """
        # Data
        vx = VOXFORGE(self.bdir, label_type="lang", use_cache=True)
        #vx.find_max_len()
        # Hard-coded max length instead of scanning the dataset (see the
        # commented-out find_max_len call above).
        vx.maxlen = 150000
        T = tat.Compose([
                tat.PadTrim(vx.maxlen),
                tat.MEL(n_mels=224),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                # tvt.Scale is the old torchvision name for Resize.
                tvt.Scale((224, 224)),
                tvt.ToTensor(),
            ])
        TT = spl_transforms.LENC(vx.LABELS)
        vx.transform = T
        vx.target_transform = TT
        dl = data.DataLoader(vx, batch_size = 25, shuffle=True)

        # Model and Loss
        model = models.resnet.resnet34(True)
        print(model)
        criterion = nn.CrossEntropyLoss()
        # Only the final fc layer's parameters are optimized; the pretrained
        # backbone stays frozen (no grads applied by the optimizer).
        # NOTE(review): `model[1].fc` indexes into the model — assumes
        # models.resnet.resnet34 returns an indexable container, not a bare
        # torchvision ResNet; verify against the local models package.
        plist = nn.ParameterList()
        #plist.extend(list(model[0].parameters()))
        plist.extend(list(model[1].fc.parameters()))
        #plist.extend(list(model.parameters()))
        #optimizer = torch.optim.SGD(plist, lr=0.0001, momentum=0.9)
        optimizer = torch.optim.Adam(plist, lr=0.0001)

        train_losses = []
        valid_losses = []
        for i, (mb, tgts) in enumerate(dl):
            model.train()
            # Switch the dataset to the training split for this step.
            vx.set_split("train")
            mb, tgts = Variable(mb), Variable(tgts)
            model.zero_grad()
            out = model(mb)
            loss = criterion(out, tgts)
            loss.backward()
            optimizer.step()
            # loss.data[0]: pre-0.4 way of extracting a Python float.
            train_losses.append(loss.data[0])
            print(loss.data[0])
            if i % 5 == 0:
                start = time.time()
                model.eval()
                # Flip the same dataset object to the validation split; the
                # inner loop then iterates the full DataLoader over it.
                vx.set_split("valid")
                running_validation_loss = 0
                correct = 0
                for mb_valid, tgts_valid in dl:
                    mb_valid, tgts_valid = Variable(mb_valid), Variable(tgts_valid)
                    out_valid = model(mb_valid)
                    loss_valid = criterion(out_valid, tgts_valid)
                    running_validation_loss += loss_valid.data[0]
                    correct += (out_valid.data.max(1)[1] == tgts_valid.data).sum()
                print_running_time(start)
                # Accuracy is normalized by the current dataset length
                # (len(vx) after set_split — presumably the valid split size;
                # TODO confirm).
                valid_losses.append((running_validation_loss, correct / len(vx)))
                print("loss: {}, acc: {}".format(running_validation_loss, correct / len(vx)))
            # Stop after 12 training steps — this is only a smoke test.
            if i == 11: break
            vx.set_split("train")
Example #4
0
    def test_mel(self):
        """Verify tensor ranks after each stage: Scale -> MEL -> BLC2CBL."""
        scaled = self.sig.clone()
        scaled = transforms.Scale()(scaled)
        self.assertTrue(scaled.dim() == 2)

        out = transforms.MEL()(scaled)
        self.assertTrue(out.dim() == 3)
        out = transforms.BLC2CBL()(out)
        self.assertTrue(out.dim() == 3)
Example #5
0
    def get_dataloader(self):
        """Build the AUDIOSET dataset with a model-specific transform pipeline.

        The input transform `T` is chosen by substring-matching
        `self.model_name` (conv/squeezenet -> MEL image, librosa MFCC,
        hand-rolled MFCC, attention, bytenet). The target transform is
        cross-entropy or binary encoding depending on `self.loss_criterion`.

        Returns:
            (ds, dl): the configured AUDIOSET dataset and its DataLoader.

        Raises:
            ValueError: if `self.model_name` matches no known pipeline
                (previously this fell through to a NameError on `T`).
        """
        # single-label encoding only makes sense for cross-entropy
        usl = self.loss_criterion == "crossentropy"
        ds = AUDIOSET(self.data_path, dataset=self.args.dataset, noises_dir=self.noises_dir,
                      use_cache=False, num_samples=self.args.num_samples,
                      add_no_label=self.args.add_no_label, use_single_label=usl)
        if any(x in self.model_name for x in ["resnet34_conv", "resnet101_conv", "squeezenet"]):
            T = tat.Compose([
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.MEL(sr=16000, n_fft=600, hop_length=300, n_mels=self.args.freq_bands//2),
                    mgc_transforms.BLC2CBL(),
                    mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
                ])
        elif "_mfcc_librosa" in self.model_name:
            T = tat.Compose([
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.MFCC2(sr=16000, n_fft=600, hop_length=300, n_mfcc=12),
                    mgc_transforms.Scale(),
                    mgc_transforms.BLC2CBL(),
                    mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
                ])
        elif "_mfcc" in self.model_name:
            sr = 16000
            ws = 800            # window size (samples)
            hs = ws // 2        # hop size: 50% overlap
            n_fft = 512
            n_filterbanks = 26
            n_coefficients = 12
            low_mel_freq = 0
            # Hz -> mel: m = 2595 * log10(1 + f / 700); build evenly spaced
            # mel points, then map back to Hz and to FFT bin indices.
            high_freq_mel = (2595 * math.log10(1 + (sr / 2) / 700))
            mel_pts = torch.linspace(low_mel_freq, high_freq_mel, n_filterbanks + 2)
            hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))
            bins = torch.floor((n_fft + 1) * hz_pts / sr)
            td = {
                    "RfftPow": mgc_transforms.RfftPow(n_fft),
                    "FilterBanks": mgc_transforms.FilterBanks(n_filterbanks, bins),
                    "MFCC": mgc_transforms.MFCC(n_filterbanks, n_coefficients),
                 }

            T = tat.Compose([
                    mgc_transforms.Preemphasis(),
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.Sig2Features(ws, hs, td),
                    mgc_transforms.DummyDim(),
                    mgc_transforms.Scale(),
                    tat.BLC2CBL(),
                    mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
                ])
        elif "attn" in self.model_name:
            T = tat.Compose([
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.MEL(sr=16000, n_fft=600, hop_length=300, n_mels=self.args.freq_bands//2),
                    mgc_transforms.SqueezeDim(2),
                    tat.LC2CL(),
                ])
        elif "bytenet" in self.model_name:
            T = tat.Compose([
                    mgc_transforms.SimpleTrim(self.max_len),
                    mgc_transforms.Scale(),
                    tat.LC2CL(),
                ])
        else:
            # Fail fast with a clear message instead of an unbound-local
            # NameError on `T` below.
            raise ValueError("unsupported model_name: {}".format(self.model_name))
        ds.transform = T
        if self.loss_criterion == "crossentropy":
            TT = mgc_transforms.XEntENC(ds.labels_dict)
        else:
            TT = mgc_transforms.BinENC(ds.labels_dict)
        ds.target_transform = TT
        ds.use_cache = self.use_cache
        if self.use_cache:
            ds.init_cache()
        if self.use_precompute:
            ds.load_precompute(self.model_name)
        dl = data.DataLoader(ds, batch_size=self.batch_size, drop_last=True,
                             num_workers=self.num_workers, collate_fn=bce_collate,
                             shuffle=True)
        if "attn" in self.model_name:
            # attention models need length-sorted batches for packing
            dl.collate_fn = sort_collate
        return ds, dl