コード例 #1
0
ファイル: nnet_helpers.py プロジェクト: taksau/mss_pytorch
def test_nnet(nnet, seqlen=100, olap=40, wsz=2049, N=4096, hop=384, B=16):
    nnet[0].eval()
    nnet[1].eval()
    nnet[2].eval()
    nnet[3].eval()
    L = olap/2
    seg = 2
    w = tf.hamming(wsz, True)
    x, fs = Io.wavRead('/home/mis/Documents/Python/Projects/SourceSeparation/testFiles/supreme_test3.wav', mono=True)

    mx, px = tf.TimeFrequencyDecomposition.STFT(x, w, N, hop)

    mx, px, _ = prepare_overlap_sequences(mx, px, mx, seqlen, olap, B)
    vs_out = np.zeros((mx.shape[0], seqlen-olap, wsz), dtype=np.float32)
    mask_out1 = np.zeros((mx.shape[0], seqlen-olap, wsz), dtype=np.float32)

    for batch in xrange(mx.shape[0]/B):
        # Mixture to Singing voice
        H_enc = nnet[0](mx[batch * B: (batch+1)*B, :, :])
        H_j_dec = it_infer.iterative_recurrent_inference(nnet[1], H_enc,
                                                         criterion=None, tol=1e-3, max_iter=10)

        vs_hat, mask = nnet[2](H_j_dec, mx[batch * B: (batch+1)*B, :, :])
        y_out = nnet[3](vs_hat)
        vs_out[batch * B: (batch+1)*B, :, :] = y_out.data.cpu().numpy()

        mask_out1[batch * B: (batch+1)*B, :, :] = mask.data.cpu().numpy()

    vs_out.shape = (vs_out.shape[0]*vs_out.shape[1], wsz)
    mask_out1.shape = (mask_out1.shape[0]*mask_out1.shape[1], wsz)

    if olap == 1:
        mx = np.ascontiguousarray(mx, dtype=np.float32)
        px = np.ascontiguousarray(px, dtype=np.float32)
    else:
        mx = np.ascontiguousarray(mx[:, olap/2:-olap/2, :], dtype=np.float32)
        px = np.ascontiguousarray(px[:, olap/2:-olap/2, :], dtype=np.float32)

    mx.shape = (mx.shape[0]*mx.shape[1], wsz)
    px.shape = (px.shape[0]*px.shape[1], wsz)

    # Approximated sources
    # Iterative G-L algorithm
    for GLiter in range(10):
        y_recb = tf.TimeFrequencyDecomposition.iSTFT(vs_out, px, wsz, hop, True)
        _, px = tf.TimeFrequencyDecomposition.STFT(y_recb, tf.hamming(wsz, True), N, hop)

    x = x[olap/2 * hop:]

    Io.audioWrite(y_recb, 44100, 16, 'results/test_sv.mp3', 'mp3')
    Io.audioWrite(x[:len(y_recb)], 44100, 16, 'results/test_mix.mp3', 'mp3')

    return None
コード例 #2
0
ファイル: nnet_helpers.py プロジェクト: odol0503/mss_pytorch
def get_data(current_set,
             set_size,
             wsz=2049,
             N=4096,
             hop=384,
             T=100,
             L=20,
             B=16):
    """
        Method to acquire training data. The STFT analysis is included.
        Args:
            current_set      : (int)       An integer denoting the current training set.
            set_size         : (int)       The amount of files a set has.
            wsz              : (int)       Window size in samples.
            N                : (int)       The FFT size.
            hop              : (int)       Hop size in samples.
            T                : (int)       Length of the time-sequence.
            L                : (int)       Number of context frames from the time-sequence.
            B                : (int)       Batch size.

        Returns:
            ms_train        :  (3D Array)  Mixture magnitude training data, for the current set.
            vs_train        :  (3D Array)  Singing voice magnitude training data, for the current set.

    """

    # Generate full paths for dev and test
    dev_mixtures_list = sorted(os.listdir(mixtures_path + foldersList[0]))
    dev_mixtures_list = [
        mixtures_path + foldersList[0] + '/' + i for i in dev_mixtures_list
    ]
    dev_sources_list = sorted(os.listdir(sources_path + foldersList[0]))
    dev_sources_list = [
        sources_path + foldersList[0] + '/' + i for i in dev_sources_list
    ]

    # Current lists for training
    c_train_slist = dev_sources_list[(current_set - 1) * set_size:current_set *
                                     set_size]
    c_train_mlist = dev_mixtures_list[(current_set - 1) *
                                      set_size:current_set * set_size]

    for index in range(len(c_train_mlist)):

        # print('Reading:' + c_train_mlist[index])

        # Reading
        vox, _ = Io.wavRead(os.path.join(c_train_slist[index], keywords[3]),
                            mono=False)
        mix, _ = Io.wavRead(os.path.join(c_train_mlist[index], keywords[4]),
                            mono=False)

        # STFT Analysing
        ms_seg, _ = tf.TimeFrequencyDecomposition.STFT(
            0.5 * np.sum(mix, axis=-1), tf.hamming(wsz, True), N, hop)
        vs_seg, _ = tf.TimeFrequencyDecomposition.STFT(
            0.5 * np.sum(vox, axis=-1), tf.hamming(wsz, True), N, hop)

        # Remove null frames
        ms_seg = ms_seg[3:-3, :]
        vs_seg = vs_seg[3:-3, :]

        # Stack some spectrograms and fit
        if index == 0:
            ms_train = ms_seg
            vs_train = vs_seg
        else:
            ms_train = np.vstack((ms_train, ms_seg))
            vs_train = np.vstack((vs_train, vs_seg))

    # Data preprocessing
    # Freeing up some memory
    ms_seg = None
    vs_seg = None

    # Learning the filtering process
    mask = Fm(ms_train, vs_train, ms_train, [], [], alpha=1., method='IRM')
    vs_train = mask()
    vs_train *= 2.
    vs_train = np.clip(vs_train, a_min=0., a_max=1.)
    ms_train = np.clip(ms_train, a_min=0., a_max=1.)
    mask = None
    ms_train, vs_train, _ = prepare_overlap_sequences(ms_train, vs_train,
                                                      ms_train, T, L * 2, B)

    return ms_train, vs_train
コード例 #3
0
ファイル: nnet_helpers.py プロジェクト: odol0503/mss_pytorch
def test_nnet(nnet, seqlen=100, olap=40, wsz=2049, N=4096, hop=384, B=16):
    """
        Method to test the model on some data. Writes the outcomes in ".wav" format and.
        stores them under the defined results path.
        Args:
            nnet             : (List)      A list containing the Pytorch modules of the skip-filtering model.
            seqlen           : (int)       Length of the time-sequence.
            olap             : (int)       Overlap between spectrogram time-sequences
                                           (to recover the missing information from the context information).
            wsz              : (int)       Window size in samples.
            N                : (int)       The FFT size.
            hop              : (int)       Hop size in samples.
            B                : (int)       Batch size.
    """
    nnet[0].eval()
    nnet[1].eval()
    nnet[2].eval()
    nnet[3].eval()
    L = olap / 2
    w = tf.hamming(wsz, True)
    x, fs = Io.wavRead('results/test_files/test.wav', mono=True)

    mx, px = tf.TimeFrequencyDecomposition.STFT(x, w, N, hop)
    mx, px, _ = prepare_overlap_sequences(mx, px, mx, seqlen, olap, B)
    vs_out = np.zeros((mx.shape[0], seqlen - olap, wsz), dtype=np.float32)

    for batch in xrange(mx.shape[0] / B):
        # Mixture to Singing voice
        H_enc = nnet[0](mx[batch * B:(batch + 1) * B, :, :])
        H_j_dec = it_infer.iterative_recurrent_inference(nnet[1],
                                                         H_enc,
                                                         criterion=None,
                                                         tol=1e-3,
                                                         max_iter=10)

        vs_hat, mask = nnet[2](H_j_dec, mx[batch * B:(batch + 1) * B, :, :])
        y_out = nnet[3](vs_hat)
        vs_out[batch * B:(batch + 1) * B, :, :] = y_out.data.cpu().numpy()

    vs_out.shape = (vs_out.shape[0] * vs_out.shape[1], wsz)

    if olap == 1:
        mx = np.ascontiguousarray(mx, dtype=np.float32)
        px = np.ascontiguousarray(px, dtype=np.float32)
    else:
        mx = np.ascontiguousarray(mx[:, olap / 2:-olap / 2, :],
                                  dtype=np.float32)
        px = np.ascontiguousarray(px[:, olap / 2:-olap / 2, :],
                                  dtype=np.float32)

    mx.shape = (mx.shape[0] * mx.shape[1], wsz)
    px.shape = (px.shape[0] * px.shape[1], wsz)

    # Approximated sources
    # Iterative G-L algorithm
    for GLiter in range(10):
        y_recb = tf.TimeFrequencyDecomposition.iSTFT(vs_out, px, wsz, hop,
                                                     True)
        _, px = tf.TimeFrequencyDecomposition.STFT(y_recb,
                                                   tf.hamming(wsz,
                                                              True), N, hop)

    x = x[olap / 2 * hop:]

    Io.wavWrite(y_recb, 44100, 16, 'results/test_files/test_sv.wav')
    Io.wavWrite(x[:len(y_recb)], 44100, 16, 'results/test_files/test_mix.wav')

    return None
コード例 #4
0
ファイル: nnet_helpers.py プロジェクト: odol0503/mss_pytorch
def test_eval(nnet, B, T, N, L, wsz, hop):
    """
        Method to test the model on the test data. Writes the outcomes in ".wav" format and.
        stores them under the defined results path. Optionally, it performs BSS-Eval using
        MIREval python toolbox (Used only for comparison to BSSEval Matlab implementation).
        The evaluation results are stored under the defined save path.
        Args:
            nnet             : (List)      A list containing the Pytorch modules of the skip-filtering model.
            B                : (int)       Batch size.
            T                : (int)       Length of the time-sequence.
            N                : (int)       The FFT size.
            L                : (int)       Number of context frames from the time-sequence.
            wsz              : (int)       Window size in samples.
            hop              : (int)       Hop size in samples.
    """
    nnet[0].eval()
    nnet[1].eval()
    nnet[2].eval()
    nnet[3].eval()

    def my_res(mx, vx, L, wsz):
        """
            A helper function to reshape data according
            to the context frame.
        """
        mx = np.ascontiguousarray(mx[:, L:-L, :], dtype=np.float32)
        mx.shape = (mx.shape[0] * mx.shape[1], wsz)
        vx = np.ascontiguousarray(vx[:, L:-L, :], dtype=np.float32)
        vx.shape = (vx.shape[0] * vx.shape[1], wsz)

        return mx, vx

    # Paths for loading and storing the test-set
    # Generate full paths for test
    test_sources_list = sorted(os.listdir(sources_path + foldersList[1]))
    test_sources_list = [
        sources_path + foldersList[1] + '/' + i for i in test_sources_list
    ]

    # Initializing the containers of the metrics
    sdr = []
    sir = []
    sar = []

    for indx in xrange(len(test_sources_list)):
        print('Reading:' + test_sources_list[indx])
        # Reading
        bass, _ = Io.wavRead(os.path.join(test_sources_list[indx],
                                          keywords[0]),
                             mono=False)
        drums, _ = Io.wavRead(os.path.join(test_sources_list[indx],
                                           keywords[1]),
                              mono=False)
        oth, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[2]),
                            mono=False)
        vox, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[3]),
                            mono=False)

        bk_true = np.sum(bass + drums + oth, axis=-1) * 0.5
        mix = np.sum(bass + drums + oth + vox, axis=-1) * 0.5
        sv_true = np.sum(vox, axis=-1) * 0.5

        # STFT Analysing
        mx, px = tf.TimeFrequencyDecomposition.STFT(mix, tf.hamming(wsz, True),
                                                    N, hop)

        # Data reshaping (magnitude and phase)
        mx, px, _ = prepare_overlap_sequences(mx, px, px, T, 2 * L, B)

        # The actual "denoising" part
        vx_hat = np.zeros((mx.shape[0], T - L * 2, wsz), dtype=np.float32)

        for batch in xrange(mx.shape[0] / B):
            H_enc = nnet[0](mx[batch * B:(batch + 1) * B, :, :])

            H_j_dec = it_infer.iterative_recurrent_inference(nnet[1],
                                                             H_enc,
                                                             criterion=None,
                                                             tol=1e-3,
                                                             max_iter=10)

            vs_hat, mask = nnet[2](H_j_dec,
                                   mx[batch * B:(batch + 1) * B, :, :])
            y_out = nnet[3](vs_hat)
            vx_hat[batch * B:(batch + 1) * B, :, :] = y_out.data.cpu().numpy()

        # Final reshaping
        vx_hat.shape = (vx_hat.shape[0] * vx_hat.shape[1], wsz)
        mx, px = my_res(mx, px, L, wsz)

        # Time-domain recovery
        # Iterative G-L algorithm
        for GLiter in range(10):
            sv_hat = tf.TimeFrequencyDecomposition.iSTFT(
                vx_hat, px, wsz, hop, True)
            _, px = tf.TimeFrequencyDecomposition.STFT(sv_hat,
                                                       tf.hamming(wsz, True),
                                                       N, hop)

        # Removing the samples that no estimation exists
        mix = mix[L * hop:]
        sv_true = sv_true[L * hop:]
        bk_true = bk_true[L * hop:]

        # Background music estimation
        if len(sv_true) > len(sv_hat):
            bk_hat = mix[:len(sv_hat)] - sv_hat
        else:
            bk_hat = mix - sv_hat[:len(mix)]

        # Disk writing for external BSS_eval using DSD100-tools (used in our paper)
        Io.wavWrite(
            sv_true, 44100, 16,
            os.path.join(save_path, 'tf_true_sv_' + str(indx) + '.wav'))
        Io.wavWrite(
            bk_true, 44100, 16,
            os.path.join(save_path, 'tf_true_bk_' + str(indx) + '.wav'))
        Io.wavWrite(sv_hat, 44100, 16,
                    os.path.join(save_path, 'tf_hat_sv_' + str(indx) + '.wav'))
        Io.wavWrite(bk_hat, 44100, 16,
                    os.path.join(save_path, 'tf_hat_bk_' + str(indx) + '.wav'))
        Io.wavWrite(mix, 44100, 16,
                    os.path.join(save_path, 'tf_mix_' + str(indx) + '.wav'))

        # Internal BSSEval using librosa (just for comparison)
        if len(sv_true) > len(sv_hat):
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true[:len(sv_hat)], bk_true[:len(sv_hat)]],
                [sv_hat, bk_hat])
        else:
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true, bk_true],
                [sv_hat[:len(sv_true)], bk_hat[:len(sv_true)]])

        sdr.append(c_sdr)
        sir.append(c_sir)
        sar.append(c_sar)

        # Storing the results iteratively
        pickle.dump(sdr, open(os.path.join(save_path, 'SDR.p'), 'wb'))
        pickle.dump(sir, open(os.path.join(save_path, 'SIR.p'), 'wb'))
        pickle.dump(sar, open(os.path.join(save_path, 'SAR.p'), 'wb'))

    return None