Example #1
0
def test_nnet(nnet, seqlen=100, olap=40, wsz=2049, N=4096, hop=384, B=16):
    """
        Method to run the model on one hard-coded audio file. Writes the
        estimated source to 'results/test_sv.mp3' and the time-aligned
        mixture to 'results/test_mix.mp3'.
        Args:
            nnet             : (Sequence)  Four callables exposing ``eval()``; used in order as:
                                           nnet[0] on the mixture sequences, nnet[1] inside the
                                           iterative recurrent inference, nnet[2] producing
                                           (estimate, mask), nnet[3] applied on that estimate.
            seqlen           : (int)       Length of the time-sequence.
            olap             : (int)       Overlap between spectrogram time-sequences.
            wsz              : (int)       Window size in samples.
            N                : (int)       The FFT size.
            hop              : (int)       Hop size in samples.
            B                : (int)       Batch size.
        Returns:
            None
    """
    for idx in range(4):
        nnet[idx].eval()

    w = tf.hamming(wsz, True)
    x, _ = Io.wavRead('/home/mis/Documents/Python/Projects/SourceSeparation/testFiles/supreme_test3.wav', mono=True)

    mx, px = tf.TimeFrequencyDecomposition.STFT(x, w, N, hop)

    mx, px, _ = prepare_overlap_sequences(mx, px, mx, seqlen, olap, B)
    vs_out = np.zeros((mx.shape[0], seqlen - olap, wsz), dtype=np.float32)

    # Only whole batches are processed (floor division); sequences past the
    # last full batch keep their zero initialisation in vs_out.
    for batch in range(mx.shape[0] // B):
        lo, hi = batch * B, (batch + 1) * B
        mx_batch = mx[lo:hi, :, :]

        # Mixture to Singing voice
        H_enc = nnet[0](mx_batch)
        H_j_dec = it_infer.iterative_recurrent_inference(nnet[1], H_enc,
                                                         criterion=None, tol=1e-3, max_iter=10)

        # The mask returned alongside the estimate is not used here.
        vs_hat, _ = nnet[2](H_j_dec, mx_batch)
        vs_out[lo:hi, :, :] = nnet[3](vs_hat).data.cpu().numpy()

    # Flatten (sequence, frame, bin) into (frame, bin)
    vs_out.shape = (vs_out.shape[0] * vs_out.shape[1], wsz)

    # Drop the context frames of the phase so it lines up with vs_out.
    # NOTE(review): olap == 1 keeps every frame, exactly as before; whether
    # that special case matches prepare_overlap_sequences should be confirmed.
    if olap == 1:
        px = np.ascontiguousarray(px, dtype=np.float32)
    else:
        # floor(olap / 2) frames are dropped at the start and the remaining
        # ceil(olap / 2) at the end. The end index is explicit because a
        # "-0" stop would select nothing when olap == 0.
        first = olap // 2
        last = px.shape[1] - (olap - first)
        px = np.ascontiguousarray(px[:, first:last, :], dtype=np.float32)

    px.shape = (px.shape[0] * px.shape[1], wsz)

    # Approximated sources
    # Iterative G-L algorithm: resynthesise with the current phase, then
    # re-estimate the phase from the resynthesised signal.
    for GLiter in range(10):
        y_recb = tf.TimeFrequencyDecomposition.iSTFT(vs_out, px, wsz, hop, True)
        _, px = tf.TimeFrequencyDecomposition.STFT(y_recb, tf.hamming(wsz, True), N, hop)

    # Skip the mixture samples covered by the dropped leading frames
    x = x[(olap // 2) * hop:]

    Io.audioWrite(y_recb, 44100, 16, 'results/test_sv.mp3', 'mp3')
    Io.audioWrite(x[:len(y_recb)], 44100, 16, 'results/test_mix.mp3', 'mp3')

    return None
Example #2
0
def test_nnet(nnet, seqlen=100, olap=40, wsz=2049, N=4096, hop=384, B=16):
    """
        Method to test the model on some data. Reads 'results/test_files/test.wav'
        and writes the outcomes in ".wav" format under 'results/test_files/'
        ('test_sv.wav' for the estimate, 'test_mix.wav' for the aligned mixture).
        Args:
            nnet             : (List)      A list containing the Pytorch modules of the skip-filtering model.
                                           Used in order as: nnet[0] on the mixture sequences, nnet[1]
                                           inside the iterative recurrent inference, nnet[2] producing
                                           (estimate, mask), nnet[3] applied on that estimate.
            seqlen           : (int)       Length of the time-sequence.
            olap             : (int)       Overlap between spectrogram time-sequences
                                           (to recover the missing information from the context information).
            wsz              : (int)       Window size in samples.
            N                : (int)       The FFT size.
            hop              : (int)       Hop size in samples.
            B                : (int)       Batch size.
        Returns:
            None
    """
    for idx in range(4):
        nnet[idx].eval()

    w = tf.hamming(wsz, True)
    x, _ = Io.wavRead('results/test_files/test.wav', mono=True)

    mx, px = tf.TimeFrequencyDecomposition.STFT(x, w, N, hop)
    mx, px, _ = prepare_overlap_sequences(mx, px, mx, seqlen, olap, B)
    vs_out = np.zeros((mx.shape[0], seqlen - olap, wsz), dtype=np.float32)

    # Only whole batches are processed (floor division); sequences past the
    # last full batch keep their zero initialisation in vs_out.
    for batch in range(mx.shape[0] // B):
        lo, hi = batch * B, (batch + 1) * B
        mx_batch = mx[lo:hi, :, :]

        # Mixture to Singing voice
        H_enc = nnet[0](mx_batch)
        H_j_dec = it_infer.iterative_recurrent_inference(nnet[1],
                                                         H_enc,
                                                         criterion=None,
                                                         tol=1e-3,
                                                         max_iter=10)

        # The mask returned alongside the estimate is not used here.
        vs_hat, _ = nnet[2](H_j_dec, mx_batch)
        vs_out[lo:hi, :, :] = nnet[3](vs_hat).data.cpu().numpy()

    # Flatten (sequence, frame, bin) into (frame, bin)
    vs_out.shape = (vs_out.shape[0] * vs_out.shape[1], wsz)

    # Drop the context frames of the phase so it lines up with vs_out.
    # NOTE(review): olap == 1 keeps every frame, exactly as before; whether
    # that special case matches prepare_overlap_sequences should be confirmed.
    if olap == 1:
        px = np.ascontiguousarray(px, dtype=np.float32)
    else:
        # floor(olap / 2) frames are dropped at the start and the remaining
        # ceil(olap / 2) at the end. The end index is explicit because a
        # "-0" stop would select nothing when olap == 0.
        first = olap // 2
        last = px.shape[1] - (olap - first)
        px = np.ascontiguousarray(px[:, first:last, :], dtype=np.float32)

    px.shape = (px.shape[0] * px.shape[1], wsz)

    # Approximated sources
    # Iterative G-L algorithm: resynthesise with the current phase, then
    # re-estimate the phase from the resynthesised signal.
    for GLiter in range(10):
        y_recb = tf.TimeFrequencyDecomposition.iSTFT(vs_out, px, wsz, hop,
                                                     True)
        _, px = tf.TimeFrequencyDecomposition.STFT(y_recb,
                                                   tf.hamming(wsz, True),
                                                   N, hop)

    # Skip the mixture samples covered by the dropped leading frames
    x = x[(olap // 2) * hop:]

    Io.wavWrite(y_recb, 44100, 16, 'results/test_files/test_sv.wav')
    Io.wavWrite(x[:len(y_recb)], 44100, 16, 'results/test_files/test_mix.wav')

    return None
Example #3
0
def test_eval(nnet, B, T, N, L, wsz, hop):
    """
        Method to test the model on the test data. Writes the outcomes in ".wav" format and
        stores them under the defined results path. It also performs BSS-Eval through the
        `bss_eval` module (used only for comparison to the BSSEval Matlab implementation).
        The evaluation results are stored under the defined save path, and are re-written
        after every processed track.
        Args:
            nnet             : (List)      A list containing the Pytorch modules of the skip-filtering model.
            B                : (int)       Batch size.
            T                : (int)       Length of the time-sequence.
            N                : (int)       The FFT size.
            L                : (int)       Number of context frames from the time-sequence.
            wsz              : (int)       Window size in samples.
            hop              : (int)       Hop size in samples.
        Returns:
            None
    """
    for idx in range(4):
        nnet[idx].eval()

    def trim_context(frames):
        """
            A helper function that drops L context frames from both ends of
            every sequence and flattens the result to (frames, wsz).
        """
        # Explicit end index: a "-0" stop would select nothing when L == 0.
        kept = np.ascontiguousarray(frames[:, L:frames.shape[1] - L, :],
                                    dtype=np.float32)
        kept.shape = (kept.shape[0] * kept.shape[1], wsz)
        return kept

    # Paths for loading and storing the test-set
    # Generate full paths for test
    test_dir = sources_path + foldersList[1]
    test_sources_list = [
        test_dir + '/' + name for name in sorted(os.listdir(test_dir))
    ]

    # Initializing the containers of the metrics
    sdr = []
    sir = []
    sar = []

    for indx, track_dir in enumerate(test_sources_list):
        print('Reading:' + track_dir)
        # Reading (the sampling rate returned by wavRead is discarded)
        bass, drums, oth, vox = [
            Io.wavRead(os.path.join(track_dir, keywords[k]), mono=False)[0]
            for k in range(4)
        ]

        bk_true = np.sum(bass + drums + oth, axis=-1) * 0.5
        mix = np.sum(bass + drums + oth + vox, axis=-1) * 0.5
        sv_true = np.sum(vox, axis=-1) * 0.5

        # STFT Analysing
        mx, px = tf.TimeFrequencyDecomposition.STFT(mix, tf.hamming(wsz, True),
                                                    N, hop)

        # Data reshaping (magnitude and phase)
        mx, px, _ = prepare_overlap_sequences(mx, px, px, T, 2 * L, B)

        # The actual "denoising" part
        vx_hat = np.zeros((mx.shape[0], T - L * 2, wsz), dtype=np.float32)

        # Only whole batches are processed (floor division); sequences past
        # the last full batch keep their zero initialisation in vx_hat.
        for batch in range(mx.shape[0] // B):
            lo, hi = batch * B, (batch + 1) * B
            mx_batch = mx[lo:hi, :, :]

            H_enc = nnet[0](mx_batch)

            H_j_dec = it_infer.iterative_recurrent_inference(nnet[1],
                                                             H_enc,
                                                             criterion=None,
                                                             tol=1e-3,
                                                             max_iter=10)

            # The mask returned alongside the estimate is not used here.
            vs_hat, _ = nnet[2](H_j_dec, mx_batch)
            vx_hat[lo:hi, :, :] = nnet[3](vs_hat).data.cpu().numpy()

        # Final reshaping (only the phase is needed from here on)
        vx_hat.shape = (vx_hat.shape[0] * vx_hat.shape[1], wsz)
        px = trim_context(px)

        # Time-domain recovery
        # Iterative G-L algorithm: resynthesise with the current phase, then
        # re-estimate the phase from the resynthesised signal.
        for GLiter in range(10):
            sv_hat = tf.TimeFrequencyDecomposition.iSTFT(
                vx_hat, px, wsz, hop, True)
            _, px = tf.TimeFrequencyDecomposition.STFT(sv_hat,
                                                       tf.hamming(wsz, True),
                                                       N, hop)

        # Removing the samples that no estimation exists
        offset = L * hop
        mix = mix[offset:]
        sv_true = sv_true[offset:]
        bk_true = bk_true[offset:]

        # Common length of reference and estimate. mix, sv_true and bk_true
        # are assumed to share one length (they come from element-wise sums
        # of the stems, trimmed by the same offset) -- TODO confirm stems
        # always have identical shapes.
        n = min(len(sv_true), len(sv_hat))

        # Background music estimation
        bk_hat = mix[:n] - sv_hat[:n]

        # Disk writing for external BSS_eval using DSD100-tools (used in our paper)
        for prefix, signal in (('tf_true_sv_', sv_true),
                               ('tf_true_bk_', bk_true),
                               ('tf_hat_sv_', sv_hat),
                               ('tf_hat_bk_', bk_hat),
                               ('tf_mix_', mix)):
            Io.wavWrite(signal, 44100, 16,
                        os.path.join(save_path, prefix + str(indx) + '.wav'))

        # Internal BSS-Eval through the bss_eval module (just for comparison)
        c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
            [sv_true[:n], bk_true[:n]],
            [sv_hat[:n], bk_hat[:n]])

        sdr.append(c_sdr)
        sir.append(c_sir)
        sar.append(c_sar)

        # Storing the results iteratively; the context manager closes each
        # file handle before the next track is processed.
        for file_name, values in (('SDR.p', sdr), ('SIR.p', sir),
                                  ('SAR.p', sar)):
            with open(os.path.join(save_path, file_name), 'wb') as handle:
                pickle.dump(values, handle)

    return None
Example #4
0
def _save_checkpoints(tagged_modules, epoch_number):
    """Write every module's state dict to 'results/torch_sps_<tag>_<epoch>.pytorch'."""
    for tag, module in tagged_modules:
        torch.save(
            module.state_dict(),
            'results/torch_sps_' + tag + '_' + str(epoch_number) + '.pytorch')


def _load_inference_weights(tagged_modules):
    """Load every module's state dict from 'results/results_inference/torch_sps_<tag>.pytorch'."""
    for tag, module in tagged_modules:
        module.load_state_dict(
            torch.load('results/results_inference/torch_sps_' + tag +
                       '.pytorch'))


def main(training, apply_sparsity):
    """
        The main function to train and test.
        Args:
            training         : (bool)      If true, the four modules are trained and a
                                           checkpoint is written every 40 epochs. Otherwise
                                           the weights are loaded from 'results/results_inference/'.
            apply_sparsity   : (bool)      If true, the sparsity penalty is added to the
                                           training loss and plotted.
        Returns:
            encoder, decoder, sp_decoder, source_enhancement : The four modules (trained or loaded).
    """
    # Reproducible results
    np.random.seed(218)
    torch.manual_seed(218)
    torch.cuda.manual_seed(218)
    # Torch model
    # NOTE(review): the CUDA default tensor type is set unconditionally, so
    # the non-CUDA path further down is presumably never exercised on a
    # CPU-only machine -- confirm.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Analysis
    wsz = 2049  # Window-size
    Ns = 4096  # FFT size
    hop = 384  # Hop size

    # Parameters
    B = 16  # Batch-size
    T = 60  # Length of the sequence
    N = 2049  # Frequency sub-bands to be processed
    F = 744  # Frequency sub-bands for encoding
    L = 10  # Context parameter (2*L frames will be removed)
    epochs = 100  # Epochs
    init_lr = 1e-4  # Initial learning rate
    mnorm = 0.5  # L2-based norm clipping
    mask_loss_threshold = 1.5  # Scalar indicating the threshold for the time-frequency masking module
    good_loss_threshold = 0.25  # Scalar indicating the threshold for the source enhancement module

    # Data (Predefined by the DSD100 dataset and the non-instrumental/non-bleeding stems of MedleydB)
    totTrainFiles = 116
    numFilesPerTr = 4

    print('------------   Building model   ------------')
    encoder = s_s_net.BiGRUEncoder(B, T, N, F, L)
    decoder = s_s_net.Decoder(B, T, N, F, L, infr=True)
    sp_decoder = s_s_net.SparseDecoder(B, T, N, F, L)
    source_enhancement = s_s_net.SourceEnhancement(B, T, N, F, L)

    # File-name tag of each module, in the order they are saved/loaded
    tagged_modules = (('encoder', encoder), ('decoder', decoder),
                      ('sp_decoder', sp_decoder), ('se', source_enhancement))

    for _, module in tagged_modules:
        module.train(mode=True)

    use_cuda = torch.has_cudnn
    if use_cuda:
        print('------------   CUDA Enabled   --------------')
        for _, module in tagged_modules:
            module.cuda()

    # Defining objectives
    rec_criterion = loss_functions.kullback_leibler  # Reconstruction criterion

    # One flat parameter list, shared by the optimizer and the gradient
    # clipping instead of being rebuilt on every batch.
    model_params = [
        param for _, module in tagged_modules for param in module.parameters()
    ]
    optimizer = optim.Adam(model_params, lr=init_lr)

    if training:
        win_viz, winb_viz = visualize.init_visdom()
        # Over epochs
        batch_index = 0
        for epoch in range(epochs):
            print('Epoch: ' + str(epoch + 1))
            # Over the set of files
            for index in range(totTrainFiles // numFilesPerTr):
                # Get Data
                ms, vs = nnet_helpers.get_data(index + 1, numFilesPerTr, wsz,
                                               Ns, hop, T, L, B)

                # Shuffle data
                shf_indices = np.random.permutation(ms.shape[0])
                ms = ms[shf_indices]
                vs = vs[shf_indices]

                # Over batches (whole batches only)
                for batch in tqdm(range(ms.shape[0] // B)):
                    lo, hi = batch * B, (batch + 1) * B
                    ms_batch = ms[lo:hi, :, :]

                    # Mixture to Singing voice
                    H_enc = encoder(ms_batch)
                    # Iterative inference
                    H_j_dec = it_infer.iterative_recurrent_inference(
                        decoder, H_enc, criterion=None, tol=1e-3, max_iter=10)
                    vs_hat_b = sp_decoder(H_j_dec, ms_batch)[0]
                    vs_hat_b_filt = source_enhancement(vs_hat_b)

                    # Target without the L context frames on either side.
                    # Built once and used for both losses; assumes
                    # rec_criterion does not modify its first argument in
                    # place -- TODO confirm.
                    target = torch.from_numpy(vs[lo:hi, L:-L, :])
                    if use_cuda:
                        target = target.cuda()
                    target = Variable(target)

                    # Loss
                    loss = rec_criterion(target, vs_hat_b_filt)
                    loss_mask = rec_criterion(target, vs_hat_b)

                    # The masking loss only joins the objective while both
                    # losses are still above their thresholds.
                    if (loss_mask.data[0] >= mask_loss_threshold
                            and loss.data[0] >= good_loss_threshold):
                        loss += loss_mask

                    # Value plotted below; taken before the sparsity term
                    current_loss = loss.data[0]

                    # Sparsity term
                    if apply_sparsity:
                        # NOTE(review): the first term reads `.weight.data`
                        # while the second reads `.weight`; the first is
                        # presumably detached from autograd and would then
                        # add no gradient -- confirm this is intended.
                        sparsity_penalty = torch.sum(torch.abs(torch.diag(sp_decoder.ffDec.weight.data))) * 1e-2 +\
                                           torch.sum(torch.pow(source_enhancement.ffSe_dec.weight, 2.)) * 1e-4

                        loss += sparsity_penalty

                        winb_viz = visualize.viz.line(
                            X=np.arange(batch_index, batch_index + 1),
                            Y=np.reshape(sparsity_penalty.data[0], (1, )),
                            win=winb_viz,
                            update='append')

                    optimizer.zero_grad()

                    loss.backward()
                    torch.nn.utils.clip_grad_norm(model_params,
                                                  max_norm=mnorm,
                                                  norm_type=2)
                    optimizer.step()
                    # Update graphs
                    win_viz = visualize.viz.line(
                        X=np.arange(batch_index, batch_index + 1),
                        Y=np.reshape(current_loss, (1, )),
                        win=win_viz,
                        update='append')
                    batch_index += 1

            # NOTE(review): with epochs = 100 this writes checkpoints after
            # epochs 40 and 80 only; nothing is saved after the last epoch.
            if (epoch + 1) % 40 == 0:
                print('------------   Saving model   ------------')
                _save_checkpoints(tagged_modules, epoch + 1)
                print('------------       Done       ------------')
    else:
        print('-------  Loading pre-trained model   -------')
        print('-------  Loading inference weights  -------')
        _load_inference_weights(tagged_modules)
        print('-------------      Done        -------------')

    return encoder, decoder, sp_decoder, source_enhancement