Esempio n. 1
0
def main():
    """Iterate once over the DSD100 'Dev' split, applying ``MixTransform``.

    Smoke-test loop: every batch from the DataLoader is passed through
    ``MixTransform``; the transformed result is not used further.
    """
    # Length (in samples) of each excerpt requested from the dataset:
    # 5 seconds at 44100 Hz. This is NOT the DataLoader batch size (8).
    segment_length = 5 * 44100
    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev',
                     segment_length)
    transform = MixTransform()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=8,
                                             num_workers=8,
                                             shuffle=True)
    for batch in dataloader:
        transform(batch)
Esempio n. 2
0
def main():
    """Train a ChimeraPlusPlus model on the DSD100 'Dev' split for 10 epochs.

    Each batch is passed through ``MixTransform``, resampled from 44100 Hz to
    16000 Hz and converted to an STFT with a trailing (real, imag) axis of
    size 2. Channel 2 of the result is used as the mixture ``X`` and channels
    0-1 as the two sources ``S``. The loss is
    ``loss_dc_whitend(embd, Y) + loss_mi_tpsa(mask, X, S)``, where ``Y`` is
    the one-hot index of the louder source in each time-frequency bin.
    A checkpoint ``model_epoch{epoch}.pth`` is written after every epoch.
    """
    batch_size = 32
    orig_freq = 44100
    target_freq = 16000
    seconds = 5

    n_fft = 512
    win_length = 512
    hop_length = 128
    # Run the STFT once on a dummy signal only to learn the spectrogram shape.
    freq_bins, spec_time, _ = torch.stft(
        torch.zeros(seconds * target_freq), n_fft, hop_length, win_length
    ).shape

    dataset = DSD100(
        '/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev', seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=8, shuffle=True)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        # Flatten (batch, 3, time) so every channel is resampled/STFT'd alike.
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(x, n_fft, hop_length, win_length),
        # Restore the (batch, 3, ...) grouping after the per-channel STFT.
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    model = ChimeraPlusPlus(freq_bins, spec_time, 2, 20)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(10):
        sum_loss = 0.
        n_samples = 0
        # Stays NaN if the loader yields no batches (avoids a NameError below).
        ave_loss = float('nan')
        last_output_len = 0
        for step, batch in enumerate(dataloader):
            batch = transform(batch)
            n_batch = batch.shape[0]
            X = batch[:, 2, :, :, :]
            S = batch[:, :2, :, :, :]
            X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
            S_abs = torch.sqrt(torch.sum(S**2, dim=-1))
            Y = torch.eye(2)[
                torch.argmax(S_abs, dim=1)
                .reshape(n_batch, freq_bins * spec_time)
            ]
            embd, mask = model(torch.log10(X_abs.clamp(min=1e-9)))

            # Compute loss
            loss = loss_dc_whitend(embd, Y) + loss_mi_tpsa(mask, X, S)
            # Divide by the number of samples actually seen: the last batch of
            # an epoch may be smaller than ``batch_size`` (drop_last=False).
            # NOTE(review): assumes ``loss`` is a per-example mean.
            sum_loss += loss.item() * n_batch
            n_samples += n_batch
            ave_loss = sum_loss / n_samples

            # Zero gradients, perform a backward pass, and update the weights.
            loss.backward()
            optimizer.step()
            sum_grad = sum(
                torch.sum(torch.abs(p.grad))
                for p in model.parameters() if p.grad is not None
            )
            optimizer.zero_grad()

            # Print learning statistics, overwriting the previous status line.
            curr_output =\
                f'\repoch {epoch} step {step} loss={ave_loss} grad={sum_grad}'
            sys.stdout.write('\r' + ' ' * last_output_len)
            sys.stdout.write(curr_output)
            sys.stdout.flush()
            last_output_len = len(curr_output)

        sys.stdout.write('\r' + ' ' * last_output_len)
        sys.stdout.write(f'\repoch {epoch} loss={ave_loss}\n')

        torch.save(
            model.state_dict(),
            f'model_epoch{epoch}.pth'
        )
Esempio n. 3
0
def main():
    """Train ChimeraMagPhasebook on the DSD100 'Dev' split with a chosen loss.

    ``loss_function`` selects the objective:
      * ``'chimera++'``: 0.975 * ``loss_dc_whitend`` + 0.025 * ``loss_mi_tpsa``
      * ``'mask'``: 0.5 * ``loss_mi_tpsa`` + 0.5 * ``loss_csa`` on the complex
        output ``com``
      * ``'wave'``: ``loss_wa`` on waveforms produced by multiplying ``com``
        with the mixture STFT and passing the result through a MisiNetwork.
    If ``initial_epoch > 0`` the model is loaded from
    ``model_epoch{initial_epoch-1}.pth``. A checkpoint ``model_epoch{epoch}.pth``
    is written after every epoch.

    Raises:
        ValueError: if ``loss_function`` is not one of the names above.
    """
    batch_size = 16
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    waveform_length = seconds * target_freq

    n_fft = 512
    win_length = 512
    hop_length = 128
    # Build the analysis window once instead of on every STFT call.
    window = torch.hann_window(n_fft)
    # Run the STFT once on a dummy signal only to learn the spectrogram shape.
    freq_bins, spec_time, _ = torch.stft(torch.zeros(waveform_length),
                                         n_fft,
                                         hop_length,
                                         win_length,
                                         window=window).shape

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev',
                     seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=8,
                                             shuffle=True)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: x.reshape(x.shape[0] // 3, 3, waveform_length),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    def stft(x):
        """STFT over the last axis of ``x``, keeping every leading axis."""
        lead_shape = x.shape[:-1]
        spec = torch.stft(x.reshape(lead_shape.numel(), waveform_length),
                          n_fft,
                          hop_length,
                          win_length,
                          window=window)
        return spec.reshape(*lead_shape, freq_bins, spec_time, 2)

    def comp_mul(a, b):
        """Complex product of tensors whose last axis holds (real, imag)."""
        a_re, a_im = a[..., 0], a[..., 1]
        b_re, b_im = b[..., 0], b[..., 1]
        return torch.stack((a_re * b_re - a_im * b_im,
                            a_re * b_im + a_im * b_re),
                           dim=-1)

    initial_model = None  # optional checkpoint path, e.g. 'model-dc.pth'
    # First epoch index: 0 trains from scratch, > 0 resumes from the
    # checkpoint written at the end of epoch ``initial_epoch - 1``.
    initial_epoch = 20
    train_epoch = 10
    loss_function = 'wave'  # 'chimera++', 'mask', 'wave'
    # Fail fast: an unknown name would otherwise leave ``loss`` unbound.
    if loss_function not in ('chimera++', 'mask', 'wave'):
        raise ValueError(f'unknown loss_function: {loss_function!r}')
    n_misi_layers = 1
    model = ChimeraMagPhasebook(freq_bins, spec_time, 2, 20, N=600)
    if initial_model is not None:
        model.load_state_dict(torch.load(initial_model))
    if initial_epoch > 0:
        model.load_state_dict(torch.load(f'model_epoch{initial_epoch-1}.pth'))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    misiLayer = MisiNetwork(n_fft,
                            hop_length,
                            win_length,
                            layer_num=n_misi_layers)

    for epoch in range(initial_epoch, initial_epoch + train_epoch):
        sum_loss = 0.
        total_batch = 0
        # Stays NaN if the loader yields no batches (avoids a NameError below).
        ave_loss = float('nan')
        last_output_len = 0
        for step, batch in enumerate(dataloader):
            batch = transform(batch)
            n_batch = batch.shape[0]
            x, s = batch[:, 2, :], batch[:, :2, :]
            X, S = stft(x), stft(s)
            X_abs = torch.sqrt(torch.sum(X**2, dim=-1))

            embd, (mask, _, com) = model(torch.log10(X_abs.clamp(min=1e-12)),
                                         outputs=['mag', 'phasep', 'com'])

            # compute loss
            if loss_function == 'chimera++':
                # One-hot label of the louder source in every T-F bin.
                S_abs = torch.sqrt(torch.sum(S**2, dim=-1))
                Y = torch.eye(2)[torch.argmax(S_abs, dim=1).reshape(
                    n_batch, freq_bins * spec_time)]
                loss = 0.975 * loss_dc_whitend(embd, Y) \
                    + 0.025 * loss_mi_tpsa(mask, X, S, gamma=2.)
            elif loss_function == 'mask':
                loss = 0.5 * loss_mi_tpsa(mask, X, S, gamma=2.) \
                    + 0.5 * loss_csa(com, X, S)
            else:  # 'wave' (validated above)
                Shat = comp_mul(com, X.unsqueeze(1))
                shat = misiLayer(Shat, x)
                loss = loss_wa(shat, s)

            # Weight each batch by its size so ave_loss is a per-example
            # average. NOTE(review): assumes the loss functions return a
            # per-example mean — confirm against their definitions.
            sum_loss += loss.item() * n_batch
            total_batch += n_batch
            ave_loss = sum_loss / total_batch

            # Zero gradients, perform a backward pass, and update the weights.
            loss.backward()
            optimizer.step()
            sum_grad = sum(
                torch.sum(torch.abs(p.grad)) for p in model.parameters()
                if p.grad is not None)
            optimizer.zero_grad()

            # Print learning statistics, overwriting the previous status line.
            curr_output =\
                f'\repoch {epoch} step {step} loss={ave_loss} grad={sum_grad}'
            sys.stdout.write('\r' + ' ' * last_output_len)
            sys.stdout.write(curr_output)
            sys.stdout.flush()
            last_output_len = len(curr_output)

        sys.stdout.write('\r' + ' ' * last_output_len)
        sys.stdout.write(f'\repoch {epoch} loss={ave_loss}\n')

        torch.save(model.state_dict(), f'model_epoch{epoch}.pth')
Esempio n. 4
0
def main():
    """Separate a fixed slice of DSD100 'Test' excerpts and save WAV files.

    Loads ``model_file`` into a ChimeraPlusPlus model and predicts masks for
    the excerpts with indices ``[batch_idx * batch_size,
    (batch_idx + 1) * batch_size)``. Each mask is multiplied with the mixture
    magnitude, and the mixture phase is refined with a MisiLayer (constructed
    with a final argument of 5, presumably the iteration/layer count — TODO
    confirm). The script writes ``s_{i}.wav`` (reference) and
    ``shat_{i}.wav`` (estimate) for each of the two sources at 16000 Hz,
    with the excerpts concatenated in time.
    """
    model_file = 'model_chimeraplusplus_misi.pth'
    batch_size = 24
    batch_idx = 2
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    waveform_length = seconds * target_freq

    n_fft = 512
    win_length = 512
    hop_length = 128

    # Run the STFT once on a dummy signal only to learn the spectrogram shape.
    freq_bins, spec_time, _ = torch.stft(torch.zeros(waveform_length),
                                         n_fft, hop_length, win_length).shape

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Test',
                     seconds * orig_freq)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(x, n_fft, hop_length, win_length),
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    def to_channel_waves(spec):
        """Inverse-STFT a (batch, 2, F, T, 2) tensor into shape (2, batch * L).

        Each of the 2 sources becomes one long signal with the batch's
        excerpts concatenated in time.
        """
        n = spec.shape[0]
        wave = torchaudio.functional.istft(
            spec.reshape(n * 2, freq_bins, spec_time, 2),
            n_fft, hop_length, win_length)
        return wave.reshape(n, 2, waveform_length).transpose(0, 1) \
            .reshape(2, n * waveform_length)

    model = ChimeraPlusPlus(freq_bins,
                            spec_time,
                            2,
                            20,
                            activation='convex_softmax')
    model.load_state_dict(torch.load(model_file))
    model.eval()  # inference: disable any training-only layer behaviour
    misilayer = MisiLayer(n_fft, hop_length, win_length, 5)

    first = batch_idx * batch_size
    batch = transform(dataset[list(range(first, first + batch_size))])
    S = batch[:, :2, :, :, :]
    X = batch[:, 2, :, :, :]
    X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
    X_phase = X / X_abs.clamp(min=1e-9).unsqueeze(-1)
    x = torchaudio.functional.istft(X, n_fft, hop_length, win_length)

    with torch.no_grad():
        _, mask = model(torch.log10(X_abs.clamp(min=1e-9)))
        Shat_abs = mask * X_abs.unsqueeze(1)
        # Start from the mixture phase, refine it with MISI, and build the
        # estimate from the *refined* phase (previously it was discarded and
        # the raw mixture phase was used instead).
        Shat_phase = misilayer(Shat_abs, X_phase.unsqueeze(1), x)
        Shat = Shat_abs.unsqueeze(-1) * Shat_phase

    s = to_channel_waves(S)
    shat = to_channel_waves(Shat)

    for i_channel, (_s, _shat) in enumerate(zip(s, shat)):
        torchaudio.save(f's_{i_channel}.wav', _s, target_freq)
        torchaudio.save(f'shat_{i_channel}.wav', _shat, target_freq)
Esempio n. 5
0
def main():
    """Train ChimeraPlusPlus on DSD100 'Dev' with a staged loss schedule.

    * epochs ``[0, 45)``: 0.975 * ``loss_dc_whitend`` + 0.025 * ``loss_mi_tpsa``
    * epochs ``[45, 55)``: ``loss_mi_tpsa`` only
    * epochs ``[55, 60)``: ``loss_wa`` on waveforms reconstructed from the
      masked magnitude. The mixture phase is refined by ``epoch - 54`` chained
      MisiLayers (1 at epoch 55, up to 5 at epoch 59).

    If ``initial_epoch > 0`` the model is loaded from
    ``model_epoch{initial_epoch-1}.pth`` and training resumes there.
    A checkpoint ``model_epoch{epoch}.pth`` is written after every epoch.
    """
    batch_size = 32
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    waveform_length = seconds * target_freq

    n_fft = 512
    win_length = 512
    hop_length = 128
    # Run the STFT once on a dummy signal only to learn the spectrogram shape.
    freq_bins, spec_time, _ = torch.stft(torch.zeros(waveform_length),
                                         n_fft, hop_length, win_length).shape

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Dev',
                     seconds * orig_freq)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=8,
                                             shuffle=True)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(x, n_fft, hop_length, win_length),
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    n_epochs = 60
    dc_end_epoch = 45      # last deep-clustering epoch is dc_end_epoch - 1
    misi_start_epoch = 55  # first epoch trained with the MISI waveform loss
    # First epoch to run: 0 trains from scratch, > 0 resumes from the
    # checkpoint written at the end of epoch ``initial_epoch - 1``.
    # (Previously this name was only defined in a commented-out line, so the
    # schedule below raised NameError at epoch 45.)
    initial_epoch = 0

    model = ChimeraPlusPlus(freq_bins,
                            spec_time,
                            2,
                            20,
                            activation='convex_softmax')
    if initial_epoch > 0:
        model.load_state_dict(torch.load(f'model_epoch{initial_epoch-1}.pth'))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(initial_epoch, n_epochs):
        sum_loss = 0.
        n_samples = 0
        # Stays NaN if the loader yields no batches (avoids a NameError below).
        ave_loss = float('nan')
        last_output_len = 0
        for step, batch in enumerate(dataloader):
            batch = transform(batch)
            n_batch = batch.shape[0]
            X = batch[:, 2, :, :, :]
            S = batch[:, :2, :, :, :]
            X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
            embd, mask = model(torch.log10(X_abs.clamp(min=1e-12)))

            # compute loss
            if epoch < dc_end_epoch:
                # One-hot label of the louder source in every T-F bin.
                S_abs = torch.sqrt(torch.sum(S**2, dim=-1))
                Y = torch.eye(2)[torch.argmax(S_abs, dim=1).reshape(
                    n_batch, freq_bins * spec_time)]
                loss = 0.975 * loss_dc_whitend(embd, Y) + 0.025 * loss_mi_tpsa(
                    mask, X, S, gamma=2.)
            elif epoch < misi_start_epoch:
                loss = loss_mi_tpsa(mask, X, S, gamma=2.)
            else:
                # Waveforms are only needed in this stage, so compute them here.
                x = torchaudio.functional.istft(X, n_fft, hop_length,
                                                win_length)
                s = torchaudio.functional.istft(
                    S.reshape(n_batch * 2, freq_bins, spec_time, 2), n_fft,
                    hop_length, win_length).reshape(n_batch, 2,
                                                    waveform_length)
                X_phase = X / X_abs.clamp(min=1e-12).unsqueeze(-1)
                amphat = mask * X_abs.unsqueeze(1)
                phasehat = X_phase.unsqueeze(1)
                # One more MISI refinement layer for each epoch of this stage.
                for _ in range(epoch - misi_start_epoch + 1):
                    misi = MisiLayer(n_fft, hop_length, win_length)
                    phasehat = misi(amphat, phasehat, x)
                shat = torchaudio.functional.istft(
                    (amphat.unsqueeze(-1) * phasehat).reshape(
                        n_batch * 2, freq_bins, spec_time, 2), n_fft,
                    hop_length, win_length).reshape(n_batch, 2,
                                                    waveform_length)
                loss = loss_wa(shat, s)

            # Divide by the number of samples actually seen: the last batch of
            # an epoch may be smaller than ``batch_size`` (drop_last=False).
            # NOTE(review): assumes ``loss`` is a per-example mean.
            sum_loss += loss.item() * n_batch
            n_samples += n_batch
            ave_loss = sum_loss / n_samples

            # Zero gradients, perform a backward pass, and update the weights.
            loss.backward()
            optimizer.step()
            sum_grad = sum(
                torch.sum(torch.abs(p.grad)) for p in model.parameters()
                if p.grad is not None)
            optimizer.zero_grad()

            # Print learning statistics, overwriting the previous status line.
            curr_output =\
                f'\repoch {epoch} step {step} loss={ave_loss} grad={sum_grad}'
            sys.stdout.write('\r' + ' ' * last_output_len)
            sys.stdout.write(curr_output)
            sys.stdout.flush()
            last_output_len = len(curr_output)

        sys.stdout.write('\r' + ' ' * last_output_len)
        sys.stdout.write(f'\repoch {epoch} loss={ave_loss}\n')

        torch.save(model.state_dict(), f'model_epoch{epoch}.pth')
Esempio n. 6
0
def main():
    """Evaluate a ChimeraMagPhasebook + MisiNetwork model on DSD100 'Test'.

    Loads ``model_file`` and separates the excerpts with indices
    ``[batch_idx * batch_size, (batch_idx + 1) * batch_size)``. The complex
    output ``com`` is multiplied with the mixture STFT and converted to
    waveforms by a MisiNetwork. The script prints ``eval_snr`` and
    ``eval_si_sdr`` of the estimates against the references, then writes
    ``s_{i}.wav`` (reference) and ``shat_{i}.wav`` (estimate) for each of the
    two sources at 16000 Hz, with the excerpts concatenated in time.
    """
    model_file = 'model_epoch9.pth'
    batch_size = 4
    batch_idx = 2
    orig_freq = 44100
    target_freq = 16000
    seconds = 5
    # Samples per excerpt after resampling (previously referenced by the
    # ``istft`` helper but never defined).
    waveform_length = seconds * target_freq

    n_fft = 512
    win_length = 512
    hop_length = 128
    # Build the analysis/synthesis window once instead of on every call.
    window = torch.hann_window(n_fft)

    # Run the STFT once on a dummy signal only to learn the spectrogram shape.
    freq_bins, spec_time, _ = torch.stft(torch.zeros(waveform_length),
                                         n_fft, hop_length, win_length).shape

    def istft(spec):
        """Inverse STFT over the trailing (F, T, 2) axes, keeping leading axes."""
        lead_shape = spec.shape[:-3]
        wave = torchaudio.functional.istft(
            spec.reshape(lead_shape.numel(), freq_bins, spec_time, 2),
            n_fft,
            hop_length,
            win_length,
            window=window)
        return wave.reshape(*lead_shape, waveform_length)

    def comp_mul(a, b):
        """Complex product of tensors whose last axis holds (real, imag)."""
        a_re, a_im = a[..., 0], a[..., 1]
        b_re, b_im = b[..., 0], b[..., 1]
        return torch.stack((a_re * b_re - a_im * b_im,
                            a_re * b_im + a_im * b_re),
                           dim=-1)

    dataset = DSD100('/Volumes/Buffalo 2TB/Datasets/DSD100', 'Test',
                     seconds * orig_freq)
    transforms = [
        MixTransform([(0, 1, 2), 3, (0, 1, 2, 3)]),
        lambda x: x.reshape(x.shape[0] * 3, seconds * orig_freq),
        torchaudio.transforms.Resample(orig_freq, target_freq),
        lambda x: torch.stft(
            x, n_fft, hop_length, win_length, window=window),
        lambda x: x.reshape(x.shape[0] // 3, 3, freq_bins, spec_time, 2),
    ]

    def transform(x):
        for t in transforms:
            x = t(x)
        return x

    model = ChimeraMagPhasebook(freq_bins, spec_time, 2, 20, N=600)
    model.load_state_dict(torch.load(model_file))
    model.eval()  # inference: disable any training-only layer behaviour
    misi_layer = MisiNetwork(n_fft, hop_length, win_length, 1)

    first = batch_idx * batch_size
    batch = transform(dataset[list(range(first, first + batch_size))])
    n_examples = batch.shape[0]
    S = batch[:, :2, :, :, :]
    X = batch[:, 2, :, :, :]
    X_abs = torch.sqrt(torch.sum(X**2, dim=-1))
    x = istft(X)

    # No autograd graph: the estimate is later handed to torchaudio.save,
    # which needs a tensor that does not require grad.
    with torch.no_grad():
        _, (com, ) = model(torch.log10(X_abs.clamp(min=1e-9)),
                           outputs=['com'])
        Shat = comp_mul(com, X.unsqueeze(1))
        shat = misi_layer(Shat, x)
    s = istft(S)

    print(eval_snr(shat, s))
    print(eval_si_sdr(shat, s))

    # (batch, 2, L) -> (2, batch * L): one long signal per source.
    shat = shat.transpose(0, 1).reshape(2, n_examples * waveform_length)
    s = s.transpose(0, 1).reshape(2, n_examples * waveform_length)

    for i_channel, (_s, _shat) in enumerate(zip(s, shat)):
        torchaudio.save(f's_{i_channel}.wav', _s, target_freq)
        torchaudio.save(f'shat_{i_channel}.wav', _shat, target_freq)