Example #1
    def __init__(self,
                 input_sr,
                 output_sr=None,
                 melspec_buckets=80,
                 hop_length=256,
                 n_fft=1024,
                 cut_silence=False):
        """
        The parameters are by default set up to do well on a
        16kHz signal. A different sample rate may require a
        different hop_length and n_fft (e.g. doubling the
        sample rate --> doubling hop_length and doubling n_fft).
        """
        self.cut_silence = cut_silence
        self.sr = input_sr
        self.new_sr = output_sr
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.mel_buckets = melspec_buckets
        self.vad = VoiceActivityDetection(
            sample_rate=input_sr
        )  # this needs heavy tweaking, depending on the data
        self.mu_encode = MuLawEncoding()
        self.mu_decode = MuLawDecoding()
        self.meter = pyln.Meter(input_sr)  # pyloudnorm loudness meter
        self.final_sr = input_sr
        if output_sr is not None and output_sr != input_sr:
            self.resample = Resample(orig_freq=input_sr, new_freq=output_sr)
            self.final_sr = output_sr
        else:
            self.resample = lambda x: x  # identity: no resampling needed
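
A quick round-trip of what the mu_encode/mu_decode pair above does (a minimal sketch, assuming torchaudio is installed; the tensor is a fake stand-in for real audio): mu-law companding maps a float waveform in [-1, 1] to integer classes in [0, 255] and approximately back.

import torch
from torchaudio.transforms import MuLawDecoding, MuLawEncoding

wave = torch.rand(1, 16000) * 2 - 1  # fake 1 s of 16 kHz audio in [-1, 1]
encoded = MuLawEncoding(quantization_channels=256)(wave)  # integer classes in [0, 255]
decoded = MuLawDecoding(quantization_channels=256)(encoded)  # back to [-1, 1], lossy
print(encoded.min().item(), encoded.max().item(), (wave - decoded).abs().max().item())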
Example #2
import torch
import wandb
from torchaudio.transforms import MuLawEncoding
from tqdm import tqdm


def train(epoch, train_loader, model, device, optimizer, error, log_freq=50):
    model.train()  # don't forget to switch between train() and eval()!

    running_loss = 0.0  # averaged over log_freq steps: more stable than a single loss.item()
    running_correct = 0.0

    for i, (mels, wavs) in enumerate(tqdm(train_loader)):
        mels, wavs = mels.to(device), wavs.to(device)
        inp_wavs = MuLawEncoding()(wavs).float()
        # Next-sample targets: shift the waveform left by one sample, zero-pad
        # the end, then mu-law encode to class indices
        targets = torch.cat(
            [wavs[:, :, 1:],
             torch.zeros(wavs.shape[0], 1, 1).to(device)],
            dim=2)
        targets = MuLawEncoding()(targets.squeeze())

        optimizer.zero_grad()

        outputs = model(inp_wavs, mels)

        loss = error(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        pred = outputs.argmax(dim=1, keepdim=True)  # predicted mu-law class per time step
        running_correct += pred.eq(targets.view_as(pred)).cpu().sum()

        if (i + 1) % log_freq == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, (i + 1) * mels.shape[0], len(train_loader.dataset),
                100. * (i + 1) / len(train_loader),
                running_loss / log_freq,
                running_correct / log_freq / wavs.shape[0] / wavs.shape[-1]))

            wandb.log({
                "Loss": running_loss / log_freq,
                "Accuracy": running_correct / log_freq / wavs.shape[0] / wavs.shape[-1]
            })

            running_loss = 0.0
            running_correct = 0.0
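
The one-sample shift above is the standard autoregressive target setup: each input sample is trained to predict the mu-law class of the next sample. A standalone shape check (toy shapes, no model or device involved):

import torch
from torchaudio.transforms import MuLawEncoding

wavs = torch.rand(2, 1, 8) * 2 - 1  # [batch, channel, time] in [-1, 1]
targets = torch.cat([wavs[:, :, 1:], torch.zeros(2, 1, 1)], dim=2)  # shift left, zero-pad the end
targets = MuLawEncoding()(targets.squeeze())  # class indices per time step
assert targets.shape == (2, 8)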
Example #3
    def raw_collate(batch):

        pad = (args.kernel_size - 1) // 2

        # input waveform length
        wave_length = args.hop_length * args.seq_len_factor
        # input spectrogram length
        spec_length = args.seq_len_factor + pad * 2

        # max start position in spectrogram
        max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch]

        # random start position in spectrogram
        spec_offsets = [random.randint(0, offset) for offset in max_offsets]
        # random start position in waveform
        wave_offsets = [(offset + pad) * args.hop_length
                        for offset in spec_offsets]

        waveform_combine = [
            x[0][wave_offsets[i]:wave_offsets[i] + wave_length + 1]
            for i, x in enumerate(batch)
        ]
        specgram = [
            x[1][:, spec_offsets[i]:spec_offsets[i] + spec_length]
            for i, x in enumerate(batch)
        ]

        specgram = torch.stack(specgram)
        waveform_combine = torch.stack(waveform_combine)

        waveform = waveform_combine[:, :wave_length]
        target = waveform_combine[:, 1:]

        # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
        if args.loss == "crossentropy":

            if args.mulaw:
                mulaw_encode = MuLawEncoding(2**args.n_bits)
                waveform = mulaw_encode(waveform)
                target = mulaw_encode(target)

                waveform = bits_to_normalized_waveform(waveform, args.n_bits)

            else:
                target = normalized_waveform_to_bits(target, args.n_bits)

        return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)
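
The crop alignment in raw_collate follows from one spectrogram frame covering hop_length waveform samples, so shifting the waveform crop by (spec_offset + pad) * hop_length lines it up with the first frame past the conv padding. A small numeric check (the constants are hypothetical stand-ins for the args.* values):

hop_length, kernel_size, seq_len_factor = 200, 5, 4
pad = (kernel_size - 1) // 2                    # 2 frames of conv padding
wave_length = hop_length * seq_len_factor       # 800 waveform samples
spec_length = seq_len_factor + pad * 2          # 8 frames: 4 "valid" plus padding
spec_offset = 3
wave_offset = (spec_offset + pad) * hop_length  # 1000: where frame (spec_offset + pad) starts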
Example #4
from os import listdir

import torch
from torchaudio.transforms import MuLawEncoding

batch_size = 4
check = True
KL = False
saveevery = 10
MSE = True
LR = 1e-3
plot = False

data = []
for d in listdir('audio'):
    print('Loading ', d)
    for i, s in enumerate(listdir('audio/' + d)):
        # `samples` (max clips per directory) is defined elsewhere in the script
        if i == samples or len(data) >= 1e10:
            break
        info = torch.load('audio/' + d + '/' + s)
        info['pitch'] = float(len(info['pitch']))
        # Mu-law encode, then min-max normalize the class indices to [0, 1]
        info['audio'] = MuLawEncoding()(info['audio']).float()
        info['audio'] = (info['audio'] - torch.min(info['audio'])) / (
            torch.max(info['audio']) - torch.min(info['audio']))
        data.append(info)
batches = len(data) // batch_size


class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.samples = data

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
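
From here MyDataset plugs straight into a DataLoader; a minimal sketch, assuming every clip in data has the same audio length so the default collate can stack the dict values:

loader = torch.utils.data.DataLoader(MyDataset(), batch_size=batch_size, shuffle=True)
for batch in loader:
    print(batch['audio'].shape, batch['pitch'].shape)  # per-key stacked tensors
    break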
Example #5
    def __init__(self,
                 path,
                 sequence_length=20000,
                 total_dilation=0,
                 overwrite=True,
                 transform=MuLawEncoding(),
                 plot=False,
                 shift=1,
                 frontend_config=None,
                 som=None):
        """ The whole dataset is saved as a Pytorch tensor inside the directory with the name of the directory. If the
        file already exists the audio files are not read in again, unless the override option is set. Samples that are
        shorter than the given sequence length are omitted. 
        The SOM-encoded data is saved in self.data, and the mu-law & one-hot encoded data in self.data_out.
        If frontend_config=som=None, the mu-law encoded data is transferred to self_data and self.data_out set to [].
        """

        super(CustomDataset, self).__init__()
        self.data = []
        self.data_out = []
        self.sequence_length = sequence_length
        self.total_dilation = total_dilation
        self.shift = shift
        self.path = path

        dir_name = os.path.basename(os.path.normpath(self.path))

        # Check for existence of preprocessed file
        if (frontend_config is not None and som is not None
                and os.path.isfile(os.path.join(path, '{}.pt'.format(dir_name)))
                and os.path.isfile(os.path.join(path, '{}_out.pt'.format(dir_name)))
                and not overwrite):
            self.data = torch.load(os.path.join(path, '{}.pt'.format(dir_name)))
            self.data_out = torch.load(
                os.path.join(path, '{}_out.pt'.format(dir_name)))
            print('Both datasets loaded.')
        elif (frontend_config is None
              and os.path.isfile(os.path.join(path, '{}.pt'.format(dir_name)))
              and not overwrite):
            self.data = torch.load(os.path.join(path, '{}.pt'.format(dir_name)))
            print('Mu-law one-hot dataset loaded.')

        else:
            num_files = 0

            # Find and add all audio files in directory
            for _, _, fnames in os.walk(path):
                for fname in fnames:
                    # (num_files % 1) == 0 is always true, so progress is printed for every file
                    if (num_files % 1) == 0:
                        print("\rFound {} sequences.".format(num_files), end='')

                    if fname.lower().endswith(('mp3', 'wav')):

                        waveform, sample_rate = ta.load(
                            os.path.join(path, fname))

                        # Length check must happen before any transformation:
                        # after the MFCC the sequence is always too short
                        if waveform.shape[1] > sequence_length:

                            if plot:
                                plt.plot(waveform.t().numpy())
                                plt.show()

                            dout = transform(waveform).float()

                            # Remove silence from the beginning of the track
                            # TODO: implement remove_silence_start_end for non-mu-law data
                            # transformed = remove_silence_start_end(transformed)

                            # One-hot encode the 256 mu-law classes, channels-first: [256, time]
                            dout = one_hot(dout[0, :].long(), 256).float().T
                            self.data_out.append(dout)

                            if frontend_config is not None:
                                transformed = auditory_frontend(
                                    waveform, sample_rate, frontend_config)

                                if som is not None:
                                    transformed = som.transform_mfcc_seq(
                                        transformed, transform_onehot=True)

                                self.data.append(transformed.float())

                        num_files += 1

            if self.data and self.data_out:
                torch.save(self.data,
                           os.path.join(path, '{}.pt'.format(dir_name)))
                torch.save(self.data_out,
                           os.path.join(path, '{}_out.pt'.format(dir_name)))
                print('\nBoth datasets saved.')

            if not self.data and self.data_out:
                self.data = self.data_out
                torch.save(self.data,
                           os.path.join(path, '{}.pt'.format(dir_name)))
                print('\nMu-law one-hot dataset saved.')
                self.data_out = []
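
The mu-law-to-one-hot step above is easy to verify in isolation; a minimal sketch of the same transformation (256 classes, channels-first):

import torch
from torch.nn.functional import one_hot
from torchaudio.transforms import MuLawEncoding

waveform = torch.rand(1, 10) * 2 - 1              # mono clip in [-1, 1]
dout = MuLawEncoding()(waveform).float()          # mu-law class indices as floats
dout = one_hot(dout[0, :].long(), 256).float().T  # one column per sample: [256, 10]
print(dout.shape)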
Example #6
from torchaudio.transforms import MuLawEncoding

def encoder(quantization_channels):
    return MuLawEncoding(quantization_channels)
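
Hedged usage sketch for the wrapper above (the waveform here is illustrative); pick the channel count to match the model's number of output classes:

import torch

enc = encoder(256)                 # 2**8 mu-law quantization channels
wave = torch.rand(1, 100) * 2 - 1  # waveform in [-1, 1]
labels = enc(wave)                 # integer classes in [0, 255]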