def get_dataloader(self):
    """Build the VOXFORGE dataset with a model-specific transform pipeline.

    Returns:
        tuple: ``(vx, dl)`` — the VOXFORGE dataset (with ``transform`` and
        ``target_transform`` attached) and a shuffling DataLoader over it.

    Raises:
        ValueError: if ``self.model_name`` matches no known pipeline
            (previously this fell through and crashed later with a
            confusing NameError on ``T``).
    """
    vx = VOXFORGE(args.data_path, langs=args.languages,
                  label_type="lang", use_cache=args.use_cache,
                  use_precompute=args.use_precompute)
    if self.model_name in ("resnet34_conv", "resnet101_conv"):
        # mel-spectrogram -> channels-first image, resized for resnet input
        T = tat.Compose([
            #tat.PadTrim(self.max_len),
            tat.MEL(n_mels=224),
            tat.BLC2CBL(),
            tvt.ToPILImage(),
            tvt.Resize((224, 224)),
            tvt.ToTensor(),
        ])
        TT = spl_transforms.LENC(vx.LABELS)
    elif self.model_name == "resnet34_mfcc":
        # hand-rolled MFCC front end: mel filterbank bin edges derived from
        # the standard mel-scale formula mel = 2595 * log10(1 + f / 700)
        sr = 16000
        ws = 800
        hs = ws // 2
        n_fft = 512  # 256
        n_filterbanks = 26
        n_coefficients = 12
        low_mel_freq = 0
        high_freq_mel = (2595 * math.log10(1 + (sr / 2) / 700))
        mel_pts = torch.linspace(low_mel_freq, high_freq_mel, n_filterbanks + 2)  # sr = 16000
        hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))
        bins = torch.floor((n_fft + 1) * hz_pts / sr)
        td = {
            "RfftPow": spl_transforms.RfftPow(n_fft),
            "FilterBanks": spl_transforms.FilterBanks(n_filterbanks, bins),
            "MFCC": spl_transforms.MFCC(n_filterbanks, n_coefficients),
        }
        T = tat.Compose([
            tat.Scale(),
            #tat.PadTrim(self.max_len, fill_value=1e-8),
            spl_transforms.Preemphasis(),
            spl_transforms.Sig2Features(ws, hs, td),
            spl_transforms.DummyDim(),
            tat.BLC2CBL(),
            tvt.ToPILImage(),
            tvt.Resize((224, 224)),
            tvt.ToTensor(),
        ])
        TT = spl_transforms.LENC(vx.LABELS)
    else:
        # fail fast instead of leaving T / TT unbound
        raise ValueError("unknown model_name: {}".format(self.model_name))
    vx.transform = T
    vx.target_transform = TT
    if args.use_precompute:
        vx.load_precompute(args.model_name)
    dl = data.DataLoader(vx, batch_size=args.batch_size,
                         num_workers=args.num_workers, shuffle=True)
    return vx, dl
def test_mel(self):
    """Check tensor ranks through Scale -> MEL -> BLC2CBL and smoke-test
    the transforms' string representations."""
    scaled = transforms.Scale()(self.sig.clone())
    self.assertTrue(scaled.dim() == 2)
    mel = transforms.MEL()(scaled)
    self.assertTrue(mel.dim() == 3)
    cbl = transforms.BLC2CBL()(mel)
    self.assertTrue(cbl.dim() == 3)
    # __repr__ must not raise for freshly constructed transforms
    for t in (transforms.MEL(), transforms.BLC2CBL()):
        t.__repr__()
def test1(self):
    """End-to-end smoke test: fine-tune a pretrained resnet34 head on
    VOXFORGE mel-spectrogram "images", running a validation pass every
    5th minibatch and stopping after 12 minibatches.

    NOTE(review): uses legacy PyTorch idioms (`Variable`, `loss.data[0]`,
    `tvt.Scale`) — this code predates PyTorch 0.4 / torchvision 0.2.
    """
    # Data
    vx = VOXFORGE(self.bdir, label_type="lang", use_cache=True)
    #vx.find_max_len()
    vx.maxlen = 150000
    # pad/trim to fixed length, then mel-spectrogram -> 224x224 tensor
    T = tat.Compose([
        tat.PadTrim(vx.maxlen),
        tat.MEL(n_mels=224),
        tat.BLC2CBL(),
        tvt.ToPILImage(),
        tvt.Scale((224, 224)),
        tvt.ToTensor(),
    ])
    # label encoder: language string -> integer class index
    TT = spl_transforms.LENC(vx.LABELS)
    vx.transform = T
    vx.target_transform = TT
    dl = data.DataLoader(vx, batch_size = 25, shuffle=True)

    # Model and Loss
    model = models.resnet.resnet34(True)
    print(model)
    criterion = nn.CrossEntropyLoss()
    # only the parameters added to plist are optimized (feature layers frozen)
    plist = nn.ParameterList()
    #plist.extend(list(model[0].parameters()))
    plist.extend(list(model[1].fc.parameters()))
    #plist.extend(list(model.parameters()))
    #optimizer = torch.optim.SGD(plist, lr=0.0001, momentum=0.9)
    optimizer = torch.optim.Adam(plist, lr=0.0001)
    train_losses = []
    valid_losses = []
    for i, (mb, tgts) in enumerate(dl):
        # one training step on the current minibatch
        model.train()
        vx.set_split("train")
        mb, tgts = Variable(mb), Variable(tgts)
        model.zero_grad()
        out = model(mb)
        loss = criterion(out, tgts)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.data[0])
        print(loss.data[0])
        if i % 5 == 0:
            # full pass over the validation split; the same DataLoader is
            # reused — presumably vx.set_split("valid") switches what it
            # yields (TODO confirm against VOXFORGE)
            start = time.time()
            model.eval()
            vx.set_split("valid")
            running_validation_loss = 0
            correct = 0
            for mb_valid, tgts_valid in dl:
                mb_valid, tgts_valid = Variable(mb_valid), Variable(tgts_valid)
                out_valid = model(mb_valid)
                loss_valid = criterion(out_valid, tgts_valid)
                running_validation_loss += loss_valid.data[0]
                correct += (out_valid.data.max(1)[1] == tgts_valid.data).sum()
            print_running_time(start)
            valid_losses.append((running_validation_loss, correct / len(vx)))
            print("loss: {}, acc: {}".format(running_validation_loss, correct / len(vx)))
        # NOTE(review): nesting is ambiguous in the original flat source;
        # placed at loop level — inside the validation branch it could
        # never fire (11 % 5 != 0)
        if i == 11:
            break
    vx.set_split("train")
def test_mel(self):
    """Verify tensor ranks through the Scale -> MEL -> BLC2CBL chain."""
    scaled = transforms.Scale()(self.sig.clone())
    self.assertTrue(scaled.dim() == 2)
    mel = transforms.MEL()(scaled)
    self.assertTrue(mel.dim() == 3)
    out = transforms.BLC2CBL()(mel)
    self.assertTrue(out.dim() == 3)
def get_dataloader(self):
    """Build the AUDIOSET dataset with a transform pipeline chosen by
    substring-matching ``self.model_name``, and wrap it in a DataLoader.

    Returns:
        tuple: ``(ds, dl)`` — the dataset (with ``transform`` /
        ``target_transform`` attached) and its DataLoader.

    Raises:
        ValueError: if ``self.model_name`` matches no known pipeline
            (previously this fell through and crashed later with a
            confusing NameError on ``T``).
    """
    # single-label targets are used only with cross-entropy loss
    usl = self.loss_criterion == "crossentropy"
    ds = AUDIOSET(self.data_path, dataset=self.args.dataset,
                  noises_dir=self.noises_dir, use_cache=False,
                  num_samples=self.args.num_samples,
                  add_no_label=self.args.add_no_label,
                  use_single_label=usl)
    if any(x in self.model_name for x in ["resnet34_conv", "resnet101_conv", "squeezenet"]):
        # mel-spectrogram -> channels-first square "image"
        T = tat.Compose([
            #tat.PadTrim(self.max_len, fill_value=1e-8),
            mgc_transforms.SimpleTrim(self.max_len),
            mgc_transforms.MEL(sr=16000, n_fft=600, hop_length=300,
                               n_mels=self.args.freq_bands//2),
            #mgc_transforms.Scale(),
            mgc_transforms.BLC2CBL(),
            mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
        ])
    elif "_mfcc_librosa" in self.model_name:
        T = tat.Compose([
            #tat.PadTrim(self.max_len, fill_value=1e-8),
            mgc_transforms.SimpleTrim(self.max_len),
            mgc_transforms.MFCC2(sr=16000, n_fft=600, hop_length=300, n_mfcc=12),
            mgc_transforms.Scale(),
            mgc_transforms.BLC2CBL(),
            mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
        ])
    elif "_mfcc" in self.model_name:
        # hand-rolled MFCC front end: mel filterbank bin edges derived from
        # the standard mel-scale formula mel = 2595 * log10(1 + f / 700)
        sr = 16000
        ws = 800
        hs = ws // 2
        n_fft = 512  # 256
        n_filterbanks = 26
        n_coefficients = 12
        low_mel_freq = 0
        high_freq_mel = (2595 * math.log10(1 + (sr/2) / 700))
        mel_pts = torch.linspace(low_mel_freq, high_freq_mel, n_filterbanks + 2)  # sr = 16000
        hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))
        bins = torch.floor((n_fft + 1) * hz_pts / sr)
        td = {
            "RfftPow": mgc_transforms.RfftPow(n_fft),
            "FilterBanks": mgc_transforms.FilterBanks(n_filterbanks, bins),
            "MFCC": mgc_transforms.MFCC(n_filterbanks, n_coefficients),
        }
        T = tat.Compose([
            #tat.PadTrim(self.max_len, fill_value=1e-8),
            mgc_transforms.Preemphasis(),
            mgc_transforms.SimpleTrim(self.max_len),
            mgc_transforms.Sig2Features(ws, hs, td),
            mgc_transforms.DummyDim(),
            mgc_transforms.Scale(),
            tat.BLC2CBL(),
            mgc_transforms.Resize((self.args.freq_bands, self.args.freq_bands)),
        ])
    elif "attn" in self.model_name:
        T = tat.Compose([
            mgc_transforms.SimpleTrim(self.max_len),
            mgc_transforms.MEL(sr=16000, n_fft=600, hop_length=300,
                               n_mels=self.args.freq_bands//2),
            #mgc_transforms.Scale(),
            mgc_transforms.SqueezeDim(2),
            tat.LC2CL(),
        ])
    elif "bytenet" in self.model_name:
        #offset = 714 # make clips divisible by 224
        T = tat.Compose([
            mgc_transforms.SimpleTrim(self.max_len),
            #tat.PadTrim(self.max_len),
            mgc_transforms.Scale(),
            tat.LC2CL(),
        ])
    else:
        # fail fast instead of leaving T unbound
        raise ValueError("unknown model_name: {}".format(self.model_name))
    ds.transform = T
    if self.loss_criterion == "crossentropy":
        TT = mgc_transforms.XEntENC(ds.labels_dict)
        #TT = mgc_transforms.BinENC(ds.labels_dict, dtype=torch.int64)
    else:
        TT = mgc_transforms.BinENC(ds.labels_dict)
    ds.target_transform = TT
    ds.use_cache = self.use_cache
    if self.use_cache:
        ds.init_cache()
    if self.use_precompute:
        ds.load_precompute(self.model_name)
    dl = data.DataLoader(ds, batch_size=self.batch_size, drop_last=True,
                         num_workers=self.num_workers,
                         collate_fn=bce_collate, shuffle=True)
    # attention models need length-sorted batches for packing
    if "attn" in self.model_name:
        dl.collate_fn = sort_collate
    return ds, dl