def __init__(self, input_sr, output_sr=None, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False):
    """
    The parameters are by default set up to do well on a 16kHz signal.
    A different frequency may require a different hop_length and n_fft
    (e.g. doubling the frequency --> doubling hop_length and doubling n_fft).
    """
    self.cut_silence = cut_silence
    self.sr = input_sr
    self.new_sr = output_sr
    self.hop_length = hop_length
    self.n_fft = n_fft
    self.mel_buckets = melspec_buckets
    self.vad = VoiceActivityDetection(
        sample_rate=input_sr
    )  # This needs heavy tweaking, depending on the data
    self.mu_encode = MuLawEncoding()
    self.mu_decode = MuLawDecoding()
    self.meter = pyln.Meter(input_sr)
    self.final_sr = input_sr
    if output_sr is not None and output_sr != input_sr:
        self.resample = Resample(orig_freq=input_sr, new_freq=output_sr)
        self.final_sr = output_sr
    else:
        self.resample = lambda x: x
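Below is a minimal, self-contained sketch (not taken from the class above) of the mu-law round trip that self.mu_encode and self.mu_decode perform, assuming torchaudio is installed; the dummy 16 kHz signal and the 256 quantization channels are illustrative defaults.

import torch
from torchaudio.transforms import MuLawDecoding, MuLawEncoding

waveform = torch.rand(1, 16000) * 2 - 1          # dummy signal in [-1, 1]
encode = MuLawEncoding(quantization_channels=256)
decode = MuLawDecoding(quantization_channels=256)

quantized = encode(waveform)                     # integer codes in [0, 255]
restored = decode(quantized)                     # back to roughly [-1, 1]
print(quantized.min().item(), quantized.max().item())
print((waveform - restored).abs().max().item())  # small reconstruction error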
def train(epoch, train_loader, model, device, optimizer, error, log_freq=50):
    model.train()  # don't forget to switch between train and eval!
    running_loss = 0.0  # more accurate representation of current loss than loss.item()
    running_correct = 0.0
    for i, (mels, wavs) in enumerate(tqdm(train_loader)):
        mels, wavs = mels.to(device), wavs.to(device)
        inp_wavs = MuLawEncoding()(wavs).float()
        # Shift the waveform by one sample and pad with a zero so the network
        # predicts the next sample; targets are mu-law quantized class indices.
        targets = torch.cat(
            [wavs[:, :, 1:], torch.zeros(wavs.shape[0], 1, 1).to(device)], dim=2)
        targets = MuLawEncoding()(targets.squeeze())
        optimizer.zero_grad()
        outputs = model(inp_wavs, mels)
        loss = error(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pred = outputs.data.max(1, keepdim=True)[1]
        running_correct += pred.eq(targets.data.view_as(pred)).cpu().sum()
        if (i + 1) % log_freq == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'
                .format(
                    epoch, (i + 1) * mels.shape[0], len(train_loader.dataset),
                    100. * (i + 1) / len(train_loader),
                    running_loss / log_freq,
                    running_correct / log_freq / wavs.shape[0] / wavs.shape[-1]))
            wandb.log({
                "Loss": running_loss / log_freq,
                "Accuracy": running_correct / log_freq / wavs.shape[0] / wavs.shape[-1]
            })
            running_loss = 0.0
            running_correct = 0.0
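The following standalone sketch (made-up shapes, not part of the training loop above) illustrates the input/target construction: the waveform is shifted left by one sample, padded with a zero, and mu-law quantized so the model can be trained to predict the next sample's class with a cross-entropy loss.

import torch
from torchaudio.transforms import MuLawEncoding

wavs = torch.rand(2, 1, 100) * 2 - 1                       # (batch, 1, time) in [-1, 1]
targets = torch.cat([wavs[:, :, 1:],
                     torch.zeros(wavs.shape[0], 1, 1)], dim=2)
targets = MuLawEncoding()(targets.squeeze(1))              # (batch, time) class indices
inp_wavs = MuLawEncoding()(wavs).float()                   # quantized network input
print(inp_wavs.shape, targets.shape, targets.dtype)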
def raw_collate(batch):
    pad = (args.kernel_size - 1) // 2
    # input waveform length
    wave_length = args.hop_length * args.seq_len_factor
    # input spectrogram length
    spec_length = args.seq_len_factor + pad * 2
    # max start position in spectrogram
    max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch]
    # random start position in spectrogram
    spec_offsets = [random.randint(0, offset) for offset in max_offsets]
    # random start position in waveform
    wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]

    waveform_combine = [
        x[0][wave_offsets[i]:wave_offsets[i] + wave_length + 1]
        for i, x in enumerate(batch)
    ]
    specgram = [
        x[1][:, spec_offsets[i]:spec_offsets[i] + spec_length]
        for i, x in enumerate(batch)
    ]

    specgram = torch.stack(specgram)
    waveform_combine = torch.stack(waveform_combine)

    waveform = waveform_combine[:, :wave_length]
    target = waveform_combine[:, 1:]

    # waveform: [-1, 1], target: [0, 2**bits - 1] if loss == 'crossentropy'
    if args.loss == "crossentropy":
        if args.mulaw:
            mulaw_encode = MuLawEncoding(2 ** args.n_bits)
            waveform = mulaw_encode(waveform)
            target = mulaw_encode(target)
            waveform = bits_to_normalized_waveform(waveform, args.n_bits)
        else:
            target = normalized_waveform_to_bits(target, args.n_bits)

    return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)
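A quick standalone check (hypothetical values, not part of the collate function above) of the mapping this branch relies on: MuLawEncoding(2 ** n_bits) turns a waveform in [-1, 1] into integer targets in [0, 2 ** n_bits - 1], which is what a cross-entropy loss over 2 ** n_bits classes expects.

import torch
from torchaudio.transforms import MuLawEncoding

n_bits = 8
waveform = torch.rand(4, 1000) * 2 - 1                 # batch of [-1, 1] waveforms
target = MuLawEncoding(2 ** n_bits)(waveform)          # integer class labels
assert target.min() >= 0 and target.max() <= 2 ** n_bits - 1
print(target.shape, target.dtype)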
batch_size = 4
check = True
KL = False
saveevery = 10
MSE = True
LR = 1e-3
plot = False

data = []
for d in listdir('audio'):
    print('Loading ', d)
    for i, s in enumerate(listdir('audio/' + d)):
        if i == samples or len(data) == 1e10:
            break
        info = torch.load('audio/' + d + '/' + s)
        info['pitch'] = float(len(info['pitch']))
        info['audio'] = MuLawEncoding()(info['audio']).type('torch.FloatTensor')
        info['audio'] = (info['audio'] - torch.min(info['audio'])) / (
            torch.max(info['audio']) - torch.min(info['audio']))
        data.append(info)
batches = int(len(data) / batch_size)


class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.samples = data

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
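The per-file normalization above can be isolated into a small helper; the sketch below (hypothetical helper name, assuming torchaudio) mu-law encodes a waveform and min-max rescales the codes to [0, 1], mirroring the two lines applied to info['audio'].

import torch
from torchaudio.transforms import MuLawEncoding

def mulaw_normalized(audio: torch.Tensor) -> torch.Tensor:
    """Mu-law encode a waveform and min-max rescale the codes to [0, 1]."""
    encoded = MuLawEncoding()(audio).float()
    return (encoded - encoded.min()) / (encoded.max() - encoded.min())

example = mulaw_normalized(torch.rand(1, 1000) * 2 - 1)
print(example.min().item(), example.max().item())      # 0.0 and 1.0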
def __init__(self, path, sequence_length=20000, total_dilation=0, overwrite=True,
             transform=MuLawEncoding(), plot=False, shift=1, frontend_config=None, som=None):
    """
    The whole dataset is saved as a PyTorch tensor inside the directory, under the name of the directory.
    If the file already exists, the audio files are not read in again unless the overwrite option is set.
    Samples that are shorter than the given sequence length are omitted.
    The SOM-encoded data is saved in self.data, and the mu-law & one-hot encoded data in self.data_out.
    If both frontend_config and som are None, the mu-law encoded data is transferred to self.data and
    self.data_out is set to [].
    """
    super(CustomDataset, self).__init__()
    self.data = []
    self.data_out = []
    self.sequence_length = sequence_length
    self.total_dilation = total_dilation
    self.shift = shift
    self.path = path
    dir_name = os.path.basename(os.path.normpath(self.path))

    # Check for existence of preprocessed files
    if (frontend_config is not None and som is not None
            and os.path.isfile(os.path.join(path, '{}.pt'.format(dir_name)))) \
            and os.path.isfile(os.path.join(path, '{}_out.pt'.format(dir_name))) and not overwrite:
        self.data = torch.load(os.path.join(path, '{}.pt'.format(dir_name)))
        self.data_out = torch.load(os.path.join(path, '{}_out.pt'.format(dir_name)))
        print('Both datasets loaded.')
    elif frontend_config is None and os.path.isfile(
            os.path.join(path, '{}.pt'.format(dir_name))) and not overwrite:
        self.data = torch.load(os.path.join(path, '{}.pt'.format(dir_name)))
        print('Mu-law one-hot dataset loaded.')
    else:
        num_files = 0
        # Find and add all audio files in the directory
        for _, _, fnames in os.walk(path):
            for fname in fnames:
                if (num_files % 1) == 0:
                    print("\rFound {} sequences.".format(num_files), end='')
                if fname.lower().endswith(('mp3', 'wav')):
                    waveform, sample_rate = ta.load(os.path.join(path, fname))
                    # The length check has to happen before the transformation,
                    # because after the MFCC transform the sequence is always too short.
                    if waveform.shape[1] > sequence_length:
                        if plot:
                            plt.plot(waveform.t().numpy())
                            plt.show()
                        dout = transform(waveform).float()
                        # Removing silence from the beginning of the track still needs
                        # to be implemented for non-mu-law data:
                        # transformed = remove_silence_start_end(transformed)
                        dout = one_hot(dout[0, :].int().long(), 256).float().T
                        self.data_out.append(dout)
                        if frontend_config is not None:
                            transformed = auditory_frontend(waveform, sample_rate, frontend_config)
                            if som is not None:
                                transformed = som.transform_mfcc_seq(transformed, transform_onehot=True)
                            self.data.append(transformed.float())
                        num_files += 1

        if self.data and self.data_out:
            torch.save(self.data, os.path.join(path, '{}.pt'.format(dir_name)))
            torch.save(self.data_out, os.path.join(path, '{}_out.pt'.format(dir_name)))
            print('\nBoth datasets saved.')
        if not self.data and self.data_out:
            self.data = self.data_out
            torch.save(self.data, os.path.join(path, '{}.pt'.format(dir_name)))
            print('\nMu-law one-hot dataset saved.')
            self.data_out = []
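For clarity, here is a compact, standalone version (illustrative names only, not the dataset's own code) of the mu-law plus one-hot step applied to each file above: every sample becomes a 256-way one-hot column, giving a (256, time) tensor.

import torch
from torch.nn.functional import one_hot
from torchaudio.transforms import MuLawEncoding

waveform = torch.rand(1, 4000) * 2 - 1            # dummy (channel, time) audio
classes = MuLawEncoding()(waveform)[0].long()     # (time,) values in [0, 255]
dout = one_hot(classes, 256).float().T            # (256, time) one-hot columns
print(dout.shape, dout.sum(dim=0)[:5])            # each column sums to 1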
def encoder(quantization_channels):
    return MuLawEncoding(quantization_channels)
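A possible usage of this factory (illustrative; it assumes MuLawEncoding is imported from torchaudio.transforms, as in the snippets above):

import torch

enc = encoder(256)                                  # 256-channel mu-law encoder
x = torch.rand(1, 100) * 2 - 1                      # waveform in [-1, 1]
print(enc(x).min().item(), enc(x).max().item())     # integer codes in [0, 255]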