Example #1
    def test_mel(self):
        audio = self.sig.clone()
        # Scale normalizes the raw samples to floats in [-1.0, 1.0].
        audio = transforms.Scale()(audio)
        self.assertTrue(audio.dim() == 2)
        # MEL turns the 2-D signal into a 3-D mel spectrogram.
        result = transforms.MEL()(audio)
        self.assertTrue(result.dim() == 3)
        # BLC2CBL permutes bands x length x channels to channels-first.
        result = transforms.BLC2CBL()(result)
        self.assertTrue(result.dim() == 3)

        # Smoke-test that both transforms have a working __repr__.
        repr_test = transforms.MEL()
        repr_test.__repr__()
        repr_test = transforms.BLC2CBL()
        repr_test.__repr__()
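
For reference, the BLC2CBL step above amounts to a single permute. A minimal standalone sketch, assuming torch is available and the input follows the bands x length x channels layout the transform name implies:

    import torch

    def blc2cbl(tensor):
        # Permute bands x length x channels -> channels x bands x length.
        return tensor.permute(2, 0, 1).contiguous()

    spec = torch.randn(224, 100, 1)   # (bands, length, channels)
    print(blc2cbl(spec).size())       # torch.Size([1, 224, 100])
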
Example #2
    def test1(self):
        # Data
        vx = VOXFORGE(self.bdir, label_type="lang", use_cache=True)
        #vx.find_max_len()
        vx.maxlen = 150000
        T = tat.Compose([
                tat.PadTrim(vx.maxlen),
                tat.MEL(n_mels=224),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                tvt.Scale((224, 224)),
                tvt.ToTensor(),
            ])
        TT = spl_transforms.LENC(vx.LABELS)
        vx.transform = T
        vx.target_transform = TT
        dl = data.DataLoader(vx, batch_size=25, shuffle=True)

        # Model and Loss
        model = models.resnet.resnet34(True)
        print(model)
        criterion = nn.CrossEntropyLoss()
        # Optimize only the final fc layer; the pretrained backbone
        # weights are left untouched.
        plist = nn.ParameterList()
        #plist.extend(list(model[0].parameters()))
        plist.extend(list(model[1].fc.parameters()))
        #plist.extend(list(model.parameters()))
        #optimizer = torch.optim.SGD(plist, lr=0.0001, momentum=0.9)
        optimizer = torch.optim.Adam(plist, lr=0.0001)

        train_losses = []
        valid_losses = []
        for i, (mb, tgts) in enumerate(dl):
            model.train()
            vx.set_split("train")
            mb, tgts = Variable(mb), Variable(tgts)
            model.zero_grad()
            out = model(mb)
            loss = criterion(out, tgts)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.data[0])
            print(loss.data[0])
            if i % 5 == 0:
                # Every fifth minibatch, time and run a full validation pass.
                start = time.time()
                model.eval()
                vx.set_split("valid")
                running_validation_loss = 0
                correct = 0
                for mb_valid, tgts_valid in dl:
                    mb_valid, tgts_valid = Variable(mb_valid), Variable(tgts_valid)
                    out_valid = model(mb_valid)
                    loss_valid = criterion(out_valid, tgts_valid)
                    running_validation_loss += loss_valid.data[0]
                    correct += (out_valid.data.max(1)[1] == tgts_valid.data).sum()
                print_running_time(start)
                valid_losses.append((running_validation_loss, correct / len(vx)))
                print("loss: {}, acc: {}".format(running_validation_loss, correct / len(vx)))
            if i == 11: break
            vx.set_split("train")
Example #3
    def test_mel(self):
        audio = self.sig.clone()
        audio = transforms.Scale()(audio)
        self.assertTrue(len(audio.size()) == 2)
        result = transforms.MEL()(audio)
        self.assertTrue(len(result.size()) == 3)
        result = transforms.BLC2CBL()(result)
        self.assertTrue(len(result.size()) == 3)
Example #4
    def get_dataloader(self):
        vx = VOXFORGE(args.data_path,
                      langs=args.languages,
                      label_type="lang",
                      use_cache=args.use_cache,
                      use_precompute=args.use_precompute)
        if self.model_name in ("resnet34_conv", "resnet101_conv"):
            T = tat.Compose([
                #tat.PadTrim(self.max_len),
                tat.MEL(n_mels=224),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                tvt.Resize((224, 224)),
                tvt.ToTensor(),
            ])
            TT = spl_transforms.LENC(vx.LABELS)
        elif self.model_name == "resnet34_mfcc":
            sr = 16000         # sample rate in Hz
            ws = 800           # analysis window size in samples
            hs = ws // 2       # hop size: 50% window overlap
            n_fft = 512  # 256
            n_filterbanks = 26
            n_coefficients = 12
            low_freq_mel = 0
            # Nyquist frequency (sr / 2) converted to the mel scale.
            high_freq_mel = 2595 * math.log10(1 + (sr / 2) / 700)
            mel_pts = torch.linspace(low_freq_mel, high_freq_mel,
                                     n_filterbanks + 2)
            # Map the evenly spaced mel points back to Hz, then to FFT bins.
            hz_pts = torch.floor(700 * (torch.pow(10, mel_pts / 2595) - 1))
            bins = torch.floor((n_fft + 1) * hz_pts / sr)
            td = {
                "RfftPow": spl_transforms.RfftPow(n_fft),
                "FilterBanks": spl_transforms.FilterBanks(n_filterbanks, bins),
                "MFCC": spl_transforms.MFCC(n_filterbanks, n_coefficients),
            }

            T = tat.Compose([
                tat.Scale(),
                #tat.PadTrim(self.max_len, fill_value=1e-8),
                spl_transforms.Preemphasis(),
                spl_transforms.Sig2Features(ws, hs, td),
                spl_transforms.DummyDim(),
                tat.BLC2CBL(),
                tvt.ToPILImage(),
                tvt.Resize((224, 224)),
                tvt.ToTensor(),
            ])
            TT = spl_transforms.LENC(vx.LABELS)
        else:
            raise ValueError("unsupported model_name: {}".format(self.model_name))
        vx.transform = T
        vx.target_transform = TT
        if args.use_precompute:
            vx.load_precompute(args.model_name)
        dl = data.DataLoader(vx,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=True)
        return vx, dl
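
The resnet34_mfcc branch above builds its filterbank edges with the standard HTK-style mel scale, mel = 2595 * log10(1 + hz / 700), inverted by hz = 700 * (10^(mel / 2595) - 1). A minimal roundtrip sketch of that conversion, standalone and assuming only torch:

    import torch

    def hz_to_mel(hz):
        return 2595.0 * torch.log10(1.0 + hz / 700.0)

    def mel_to_hz(mel):
        return 700.0 * (torch.pow(10.0, mel / 2595.0) - 1.0)

    sr = 16000
    n_filterbanks = 26
    # Evenly spaced points on the mel scale from 0 Hz up to Nyquist,
    # mapped back to Hz to place the triangular filter edges.
    mel_pts = torch.linspace(0.0, hz_to_mel(torch.tensor(sr / 2.0)).item(),
                             n_filterbanks + 2)
    hz_pts = mel_to_hz(mel_pts)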