# Shared imports for the main() variants below. DSD100, MixTransform, the
# Chimera models, the MISI layers, and the loss_*/eval_* helpers are
# project-local; their import paths are not shown in the source.
import sys

import torch
import torchaudio


# Smoke test: load DSD100 clips and apply the mixing transform.
def main():
    segment_samples = 5 * 44100  # samples per clip: 5 s at 44.1 kHz
    batch_n = 100  # unused
    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev',
                     segment_samples)
    transform = MixTransform()
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=8, num_workers=8, shuffle=True)
    for batch in dataloader:
        x = transform(batch)
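
# MixTransform is project-local and not shown in the source. The DSD100 stems
# are (bass, drums, other, vocals), and the index groups used in the later
# scripts, [(0, 1, 2), 3, (0, 1, 2, 3)], suggest it sums stem subsets into
# (accompaniment, vocals, mixture). A minimal sketch under that assumption:
class MixTransformSketch:
    def __init__(self, groups=((0, 1, 2), 3, (0, 1, 2, 3))):
        # Normalize each group spec to a list of stem indices.
        self.groups = [list(g) if isinstance(g, tuple) else [g]
                       for g in groups]

    def __call__(self, x):
        # x: (batch, n_stems, samples) -> (batch, len(groups), samples)
        return torch.stack(
            [x[:, g, :].sum(dim=1) for g in self.groups], dim=1)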
# Training loop for the Chimera++ network: joint deep-clustering and
# mask-inference (tPSA) objective.
def main():
    batch_size = 32
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    n_fft = 512
    win_length = 512
    hop_length = 128
    # Dummy input used only to infer the STFT output shape.
    freq_bins, spec_time, _ = torch.stft(
        torch.zeros(seconds * target_freq), n_fft, hop_length,
        win_length).shape

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev',
                     seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=8, shuffle=True)
    transforms = [
        # Sum the DSD100 stems into (accompaniment, vocals, mixture).
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(x, n_fft, hop_length, win_length),
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    model = ChimeraPlusPlus(freq_bins, spec_time, 2, 20)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(10):
        sum_loss = 0
        last_output_len = 0
        for step, batch in enumerate(dataloader):
            batch = transform(batch)
            X = batch[:, 2, :, :, :]   # mixture spectrogram
            S = batch[:, :2, :, :, :]  # source spectrograms
            X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
            S_abs = torch.sqrt(torch.sum(S**2, dim=-1))
            # One-hot assignment of each T-F bin to its dominant source.
            Y = torch.eye(2)[
                torch.argmax(S_abs, dim=1)
                .reshape(batch.shape[0], freq_bins * spec_time)
            ]

            embd, mask = model(torch.log10(X_abs.clamp(min=1e-9)))

            # Compute the combined deep-clustering + mask-inference loss.
            loss = loss_dc_whitend(embd, Y) + loss_mi_tpsa(mask, X, S)
            sum_loss += loss.item() * batch.shape[0]
            ave_loss = sum_loss / (batch.shape[0] * (step + 1))

            # Backward pass, update the weights, then zero the gradients.
            loss.backward()
            optimizer.step()
            sum_grad = sum(
                torch.sum(torch.abs(p.grad))
                for p in model.parameters() if p.grad is not None)
            optimizer.zero_grad()

            # Print learning statistics on a single updating line.
            curr_output = \
                f'\repoch {epoch} step {step} loss={ave_loss} grad={sum_grad}'
            sys.stdout.write('\r' + ' ' * last_output_len)
            sys.stdout.write(curr_output)
            sys.stdout.flush()
            last_output_len = len(curr_output)

        sys.stdout.write('\r' + ' ' * last_output_len)
        sys.stdout.write(f'\repoch {epoch} loss={ave_loss}\n')
        torch.save(model.state_dict(), f'model_epoch{epoch}.pth')
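
# loss_dc_whitend is project-local; the name matches the whitened k-means
# deep-clustering objective of Wang, Le Roux and Hershey (2018). A minimal
# sketch under that assumption, with embd: (B, TF, D) embeddings and
# Y: (B, TF, C) one-hot dominant-source labels:
def loss_dc_whitened_sketch(embd, Y, eps=1e-6):
    D, C = embd.shape[-1], Y.shape[-1]
    VtV = embd.transpose(1, 2) @ embd + eps * torch.eye(D)  # (B, D, D)
    YtY = Y.transpose(1, 2) @ Y + eps * torch.eye(C)        # (B, C, C)
    VtY = embd.transpose(1, 2) @ Y                          # (B, D, C)
    # D - tr((V'V)^-1 V'Y (Y'Y)^-1 Y'V), summed over the batch.
    t = torch.inverse(VtV) @ VtY @ torch.inverse(YtY) @ VtY.transpose(1, 2)
    return (D - t.diagonal(dim1=-2, dim2=-1).sum(dim=-1)).sum()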
# Training loop for the ChimeraMagPhasebook network with a selectable
# objective: 'chimera++' (DC + tPSA), 'mask' (tPSA + CSA), or 'wave'
# (waveform approximation through MISI layers).
def main():
    batch_size = 16
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    n_fft = 512
    win_length = 512
    hop_length = 128
    # Dummy input used only to infer the STFT output shape.
    freq_bins, spec_time, _ = torch.stft(
        torch.zeros(seconds * target_freq), n_fft, hop_length, win_length,
        window=torch.hann_window(n_fft)).shape

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev',
                     seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=8, shuffle=True)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: x.reshape(x.shape[0] // 3, 3, seconds * target_freq),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    # Batched STFT over arbitrary leading dimensions.
    stft = lambda x: torch.stft(
        x.reshape(x.shape[:-1].numel(), seconds * target_freq),
        n_fft, hop_length, win_length,
        window=torch.hann_window(n_fft)
    ).reshape(*x.shape[:-1], freq_bins, spec_time, 2)
    # Elementwise complex multiplication on (..., 2) real/imag tensors.
    comp_mul = lambda X, Y: torch.stack((
        X.unbind(-1)[0] * Y.unbind(-1)[0] - X.unbind(-1)[1] * Y.unbind(-1)[1],
        X.unbind(-1)[0] * Y.unbind(-1)[1] + X.unbind(-1)[1] * Y.unbind(-1)[0]
    ), dim=-1)

    initial_model = None  # 'model-dc.pth'
    initial_epoch = 20  # set to 0 to train from scratch
    train_epoch = 10
    loss_function = 'wave'  # 'chimera++', 'mask', 'wave'
    n_misi_layers = 1

    model = ChimeraMagPhasebook(freq_bins, spec_time, 2, 20, N=600)
    if initial_model is not None:
        model.load_state_dict(torch.load(initial_model))
    if initial_epoch > 0:
        model.load_state_dict(torch.load(f'model_epoch{initial_epoch-1}.pth'))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    misiLayer = MisiNetwork(n_fft, hop_length, win_length,
                            layer_num=n_misi_layers)

    for epoch in range(initial_epoch, initial_epoch + train_epoch):
        sum_loss = 0
        total_batch = 0
        last_output_len = 0
        for step, batch in enumerate(dataloader):
            batch = transform(batch)
            x, s = batch[:, 2, :], batch[:, :2, :]
            X, S = stft(x), stft(s)
            X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
            X_phase = X / X_abs.clamp(min=1e-12).unsqueeze(-1)
            S_abs = torch.sqrt(torch.sum(S**2, dim=-1))
            S_phase = S / S_abs.clamp(min=1e-12).unsqueeze(-1)
            # One-hot assignment of each T-F bin to its dominant source.
            Y = torch.eye(2)[torch.argmax(S_abs, dim=1).reshape(
                batch.shape[0], freq_bins * spec_time)]

            embd, (mask, phasep, com) = model(
                torch.log10(X_abs.clamp(min=1e-12)),
                outputs=['mag', 'phasep', 'com'])

            # Compute the loss for the selected objective.
            if loss_function == 'chimera++':
                loss = 0.975 * loss_dc_whitend(embd, Y) \
                    + 0.025 * loss_mi_tpsa(mask, X, S, gamma=2.)
            elif loss_function == 'mask':
                loss = 0.5 * loss_mi_tpsa(mask, X, S, gamma=2.) \
                    + 0.5 * loss_csa(com, X, S)
            elif loss_function == 'wave':
                Shat = comp_mul(com, X.unsqueeze(1))
                shat = misiLayer(Shat, x)
                loss = loss_wa(shat, s)
            sum_loss += loss.item() * batch.shape[0]
            total_batch += batch.shape[0]
            ave_loss = sum_loss / total_batch

            # Backward pass, update the weights, then zero the gradients.
            loss.backward()
            optimizer.step()
            sum_grad = sum(
                torch.sum(torch.abs(p.grad))
                for p in model.parameters() if p.grad is not None)
            optimizer.zero_grad()

            # Print learning statistics on a single updating line.
            curr_output = \
                f'\repoch {epoch} step {step} loss={ave_loss} grad={sum_grad}'
            sys.stdout.write('\r' + ' ' * last_output_len)
            sys.stdout.write(curr_output)
            sys.stdout.flush()
            last_output_len = len(curr_output)

        sys.stdout.write('\r' + ' ' * last_output_len)
        sys.stdout.write(f'\repoch {epoch} loss={ave_loss}\n')
        torch.save(model.state_dict(), f'model_epoch{epoch}.pth')
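
# MisiNetwork is project-local. MISI (Multiple Input Spectrogram Inversion,
# Gunawan & Sen 2010) iteratively adjusts the source phases so that the
# resynthesized estimates sum back to the mixture. A sketch of the assumed
# interface, mapping complex source spectrograms Shat: (B, C, F, T, 2) and
# the mixture waveform x: (B, samples) to source waveforms:
def misi_network_sketch(Shat, x, n_fft, hop_length, win_length, layer_num=1):
    B, C, F, T, _ = Shat.shape
    window = torch.hann_window(n_fft)
    mag = Shat.norm(dim=-1, keepdim=True)  # magnitudes stay fixed
    for _ in range(layer_num):
        shat = torchaudio.functional.istft(
            Shat.reshape(B * C, F, T, 2), n_fft, hop_length, win_length,
            window=window).reshape(B, C, -1)
        # Distribute the mixture-consistency error equally across sources.
        shat = shat + (x - shat.sum(dim=1)).unsqueeze(1) / C
        # Re-analyze and keep only the updated phase.
        Shat = torch.stft(shat.reshape(B * C, -1), n_fft, hop_length,
                          win_length, window=window).reshape(B, C, F, T, 2)
        Shat = mag * Shat / Shat.norm(dim=-1, keepdim=True).clamp(min=1e-12)
    return shat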
# Separate a batch of DSD100 test clips with a trained Chimera++ model,
# refine the phases with MISI, and write the references and estimates to wav.
def main():
    model_file = 'model_chimeraplusplus_misi.pth'
    batch_size = 24
    batch_idx = 2
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    n_fft = 512
    win_length = 512
    hop_length = 128
    freq_bins, spec_time, _ = torch.stft(
        torch.zeros(seconds * target_freq), n_fft, hop_length,
        win_length).shape

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Test',
                     seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=8, shuffle=False)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(x, n_fft, hop_length, win_length),
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    model = ChimeraPlusPlus(freq_bins, spec_time, 2, 20,
                            activation='convex_softmax')
    model.load_state_dict(torch.load(model_file))
    misilayer = MisiLayer(n_fft, hop_length, win_length, 5)

    # batch = transform(next(iter(dataloader)))
    batch = transform(dataset[list(
        range(batch_idx * batch_size, (batch_idx + 1) * batch_size))])
    S = batch[:, :2, :, :, :]
    X = batch[:, 2, :, :, :]
    X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
    X_phase = X / X_abs.clamp(min=1e-9).unsqueeze(-1)
    x = torchaudio.functional.istft(X, n_fft, hop_length, win_length)

    _, mask = model(torch.log10(X_abs.clamp(min=1e-9)))
    mask = mask.detach()
    Shat_abs = mask * X_abs.unsqueeze(1)
    Shat_phase = X_phase.unsqueeze(1)
    # Refine the phase estimate with MISI and resynthesize with it.
    Shat_phase = misilayer(Shat_abs, Shat_phase, x)
    Shat = Shat_abs.unsqueeze(-1) * Shat_phase

    s = torchaudio.functional.istft(
        S.reshape(batch_size * 2, freq_bins, spec_time, 2),
        n_fft, hop_length, win_length
    ).reshape(batch_size, 2, seconds * target_freq).transpose(0, 1) \
        .reshape(2, batch_size * seconds * target_freq)
    shat = torchaudio.functional.istft(
        Shat.reshape(batch_size * 2, freq_bins, spec_time, 2),
        n_fft, hop_length, win_length
    ).reshape(batch_size, 2, seconds * target_freq).transpose(0, 1) \
        .reshape(2, batch_size * seconds * target_freq)

    # Concatenate the clips per source channel and write them out.
    for i_channel, (_s, _shat) in enumerate(zip(s, shat)):
        torchaudio.save(f's_{i_channel}.wav', _s, target_freq)
        torchaudio.save(f'shat_{i_channel}.wav', _shat, target_freq)
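
# MisiLayer here maps (magnitudes, phases, mixture waveform) to refined
# phases, unlike MisiNetwork in the previous script, which maps complex
# spectrograms to waveforms. A sketch of one magnitude-preserving update
# under that assumed interface (Shat_abs: (B, C, F, T), Shat_phase:
# (B, C, F, T, 2), x: (B, samples)):
def misi_layer_sketch(Shat_abs, Shat_phase, x, n_fft, hop_length, win_length):
    B, C, F, T = Shat_abs.shape
    window = torch.hann_window(n_fft)
    shat = torchaudio.functional.istft(
        (Shat_abs.unsqueeze(-1) * Shat_phase).reshape(B * C, F, T, 2),
        n_fft, hop_length, win_length, window=window).reshape(B, C, -1)
    # Distribute the mixture-consistency error equally across the sources.
    shat = shat + (x - shat.sum(dim=1)).unsqueeze(1) / C
    Shat = torch.stft(shat.reshape(B * C, -1), n_fft, hop_length, win_length,
                      window=window).reshape(B, C, F, T, 2)
    # Return only the updated phase factor; magnitudes are kept as given.
    return Shat / Shat.norm(dim=-1, keepdim=True).clamp(min=1e-12)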
# Curriculum training for the Chimera++ network: deep clustering + tPSA up to
# epoch 45, tPSA alone up to epoch 55, then waveform approximation through a
# growing stack of MISI layers.
def main():
    batch_size = 32
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    n_fft = 512
    win_length = 512
    hop_length = 128
    freq_bins, spec_time, _ = torch.stft(
        torch.zeros(seconds * target_freq), n_fft, hop_length,
        win_length).shape

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev',
                     seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=8, shuffle=True)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(x, n_fft, hop_length, win_length),
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    model = ChimeraPlusPlus(freq_bins, spec_time, 2, 20,
                            activation='convex_softmax')
    # Last epoch of the mask-only stage; epoch - initial_epoch sets the
    # number of MISI layers in the waveform stage.
    initial_epoch = 54
    # model.load_state_dict(torch.load(f'model_epoch{initial_epoch}.pth'))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(60):
        sum_loss = 0
        last_output_len = 0
        for step, batch in enumerate(dataloader):
            batch = transform(batch)
            X = batch[:, 2, :, :, :]
            S = batch[:, :2, :, :, :]
            X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
            X_phase = X / X_abs.clamp(min=1e-12).unsqueeze(-1)
            x = torchaudio.functional.istft(X, n_fft, hop_length, win_length)
            s = torchaudio.functional.istft(
                S.reshape(batch.shape[0] * 2, freq_bins, spec_time, 2),
                n_fft, hop_length, win_length).reshape(
                    batch.shape[0], 2, seconds * target_freq)
            S_abs = torch.sqrt(torch.sum(S**2, dim=-1))
            Y = torch.eye(2)[torch.argmax(S_abs, dim=1).reshape(
                batch.shape[0], freq_bins * spec_time)]

            embd, mask = model(torch.log10(X_abs.clamp(min=1e-12)))
            amphat = mask * X_abs.unsqueeze(1)
            phasehat = X_phase.unsqueeze(1)

            # Compute the loss for the current curriculum stage.
            if epoch < 45:
                loss = 0.975 * loss_dc_whitend(embd, Y) \
                    + 0.025 * loss_mi_tpsa(mask, X, S, gamma=2.)
            elif epoch < 55:
                loss = loss_mi_tpsa(mask, X, S, gamma=2.)
            else:
                for i in range(epoch - initial_epoch):
                    l = MisiLayer(n_fft, hop_length, win_length)
                    phasehat = l(amphat, phasehat, x)
                shat = torchaudio.functional.istft(
                    (amphat.unsqueeze(-1) * phasehat).reshape(
                        batch.shape[0] * 2, freq_bins, spec_time, 2),
                    n_fft, hop_length, win_length).reshape(
                        batch.shape[0], 2, seconds * target_freq)
                loss = loss_wa(shat, s)
            sum_loss += loss.item() * batch.shape[0]
            ave_loss = sum_loss / (batch.shape[0] * (step + 1))

            # Backward pass, update the weights, then zero the gradients.
            loss.backward()
            optimizer.step()
            sum_grad = sum(
                torch.sum(torch.abs(p.grad))
                for p in model.parameters() if p.grad is not None)
            optimizer.zero_grad()

            # Print learning statistics on a single updating line.
            curr_output = \
                f'\repoch {epoch} step {step} loss={ave_loss} grad={sum_grad}'
            sys.stdout.write('\r' + ' ' * last_output_len)
            sys.stdout.write(curr_output)
            sys.stdout.flush()
            last_output_len = len(curr_output)

        sys.stdout.write('\r' + ' ' * last_output_len)
        sys.stdout.write(f'\repoch {epoch} loss={ave_loss}\n')
        torch.save(model.state_dict(), f'model_epoch{epoch}.pth')
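
# loss_mi_tpsa is project-local; the name suggests mask inference against a
# truncated phase-sensitive approximation (tPSA) target, as in Chimera++.
# A minimal sketch under that assumption, with mask: (B, C, F, T),
# X: (B, F, T, 2) and S: (B, C, F, T, 2):
def loss_mi_tpsa_sketch(mask, X, S, gamma=1.):
    X_abs = X.norm(dim=-1)
    S_abs = S.norm(dim=-1)
    # cos(angle(S) - angle(X)) from the real/imag dot product.
    cos_dif = (S * X.unsqueeze(1)).sum(dim=-1) \
        / (S_abs * X_abs.unsqueeze(1)).clamp(min=1e-12)
    # Phase-sensitive target magnitude, truncated to [0, gamma * |X|].
    target = torch.min(
        torch.max(S_abs * cos_dif, torch.zeros_like(S_abs)),
        gamma * X_abs.unsqueeze(1))
    return torch.sum((mask * X_abs.unsqueeze(1) - target) ** 2) \
        / mask.shape[0]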
# Evaluate a trained ChimeraMagPhasebook model on DSD100 test clips: report
# SNR and SI-SDR, then write the references and estimates to wav.
def main():
    model_file = 'model_epoch9.pth'
    batch_size = 4
    batch_idx = 2
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    n_fft = 512
    win_length = 512
    hop_length = 128
    waveform_length = seconds * target_freq
    freq_bins, spec_time, _ = torch.stft(
        torch.zeros(waveform_length), n_fft, hop_length, win_length).shape

    # Batched STFT/ISTFT over arbitrary leading dimensions.
    stft = lambda x: torch.stft(
        x.reshape(x.shape[:-1].numel(), waveform_length),
        n_fft, hop_length, win_length,
        window=torch.hann_window(n_fft)
    ).reshape(*x.shape[:-1], freq_bins, spec_time, 2)
    istft = lambda X: torchaudio.functional.istft(
        X.reshape(X.shape[:-3].numel(), freq_bins, spec_time, 2),
        n_fft, hop_length, win_length,
        window=torch.hann_window(n_fft)
    ).reshape(*X.shape[:-3], waveform_length)
    # Elementwise complex multiplication on (..., 2) real/imag tensors.
    comp_mul = lambda X, Y: torch.stack((
        X.unbind(-1)[0] * Y.unbind(-1)[0] - X.unbind(-1)[1] * Y.unbind(-1)[1],
        X.unbind(-1)[0] * Y.unbind(-1)[1] + X.unbind(-1)[1] * Y.unbind(-1)[0]
    ), dim=-1)

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Test',
                     seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=8, shuffle=False)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(
            x, n_fft, hop_length, win_length,
            window=torch.hann_window(n_fft)),
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    model = ChimeraMagPhasebook(freq_bins, spec_time, 2, 20, N=600)
    model.load_state_dict(torch.load(model_file))
    misi_layer = MisiNetwork(n_fft, hop_length, win_length, 1)

    # batch = transform(next(iter(dataloader)))
    batch = transform(dataset[list(
        range(batch_idx * batch_size, (batch_idx + 1) * batch_size))])
    S = batch[:, :2, :, :, :]
    X = batch[:, 2, :, :, :]
    X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
    X_phase = X / X_abs.clamp(min=1e-9).unsqueeze(-1)
    x = torchaudio.functional.istft(X, n_fft, hop_length, win_length,
                                    window=torch.hann_window(n_fft))

    _, (com,) = model(torch.log10(X_abs.clamp(min=1e-9)), outputs=['com'])
    com = com.detach()
    # Apply the complex mask to the mixture, then invert through MISI.
    Shat = comp_mul(com, X.unsqueeze(1))
    shat = misi_layer(Shat, x)
    s = istft(S)

    print(eval_snr(shat, s))
    print(eval_si_sdr(shat, s))

    # Concatenate the clips per source channel and write them out.
    shat = shat.transpose(0, 1).reshape(2, batch_size * waveform_length)
    s = s.transpose(0, 1).reshape(2, batch_size * waveform_length)
    for i_channel, (_s, _shat) in enumerate(zip(s, shat)):
        torchaudio.save(f's_{i_channel}.wav', _s, target_freq)
        torchaudio.save(f'shat_{i_channel}.wav', _shat, target_freq)
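
# eval_snr and eval_si_sdr are project-local. SI-SDR itself is standard
# (Le Roux et al., 2019, "SDR - half-baked or well done?"), so a minimal
# sketch is straightforward; shat and s are (B, C, samples), and the result
# is in dB per clip and source:
def eval_si_sdr_sketch(shat, s, eps=1e-12):
    # Project the estimate onto the reference (scale-invariant target).
    alpha = (shat * s).sum(dim=-1, keepdim=True) \
        / (s * s).sum(dim=-1, keepdim=True).clamp(min=eps)
    target = alpha * s
    noise = shat - target
    ratio = (target ** 2).sum(dim=-1) \
        / (noise ** 2).sum(dim=-1).clamp(min=eps)
    return 10 * torch.log10(ratio.clamp(min=eps))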