Example #1
def data_process_results_testing(index,
                                 voice_true,
                                 bg_true,
                                 voice_predicted,
                                 window_size,
                                 mix,
                                 mix_magnitude,
                                 mix_phase,
                                 hop,
                                 context_length,
                                 output_file_name=None):
    """Calculates SDR and SIR and creates the resulting audio files.

    :param index: The index of the current source/track.
    :type index: int
    :param voice_true: The true voice.
    :type voice_true: numpy.core.multiarray.ndarray
    :param bg_true: The true background music.
    :type bg_true: numpy.core.multiarray.ndarray
    :param voice_predicted: The predicted voice.
    :type voice_predicted: numpy.core.multiarray.ndarray
    :param window_size: The window size in samples.
    :type window_size: int
    :param mix: The mixture.
    :type mix: numpy.core.multiarray.ndarray
    :param mix_magnitude: The mixture magnitude.
    :type mix_magnitude: numpy.core.multiarray.ndarray
    :param mix_phase: The mixture phase.
    :type mix_phase: numpy.core.multiarray.ndarray
    :param hop: The hop size in samples.
    :type hop: int
    :param context_length: The context length in frames.
    :type context_length: int
    :param output_file_name: The output file names for the predicted voice\
                             and background music. If this argument is not
                             None, the function only synthesizes the voice
                             and the background music and saves them.
    :type output_file_name: list[str] | None
    :return: The values of SDR and SIR for each of the frames in\
             the current track, for both voice and background music.
    :rtype: (list[numpy.core.multiarray.ndarray], list[numpy.core.multiarray.ndarray])
    """
    voice_predicted.shape = (voice_predicted.shape[0] *
                             voice_predicted.shape[1], window_size)
    mix_magnitude, mix_phase = _context_based_reshaping(
        mix_magnitude, mix_phase, context_length, window_size)

    voice_hat = i_stft(voice_predicted, mix_phase, window_size, hop)

    # Remove the leading samples for which no estimate exists
    mix = mix[context_length * hop:]
    if output_file_name is None:
        voice_true = voice_true[context_length * hop:]
        bg_true = bg_true[context_length * hop:]
        min_len = min(len(voice_true), len(voice_hat))
        example_index = index + 1
    else:
        voice_true = None
        bg_true = None
        example_index = None
        min_len = min(len(mix), len(voice_hat))

    # Background music estimation
    bg_hat = mix[:min_len] - voice_hat[:min_len]

    if output_file_name is None:
        voice_hat_path = output_audio_paths['voice_predicted'].format(
            p=example_index)
        bg_hat_path = output_audio_paths['bg_predicted'].format(
            p=example_index)
        wav_write(
            voice_true,
            file_name=output_audio_paths['voice_true'].format(p=example_index),
            **wav_quality)
        wav_write(
            bg_true,
            file_name=output_audio_paths['bg_true'].format(p=example_index),
            **wav_quality)
        wav_write(mix,
                  file_name=output_audio_paths['mix'].format(p=example_index),
                  **wav_quality)

        # Metrics calculation
        sdr, sir = _get_me_the_metrics(
            bss_eval.bss_eval_images_framewise(
                [voice_true[:min_len], bg_true[:min_len]],
                [voice_hat[:min_len], bg_hat[:min_len]]))

    else:
        voice_hat_path = output_file_name[0]
        bg_hat_path = output_file_name[1]

        sdr = None
        sir = None

    wav_write(voice_hat, file_name=voice_hat_path, **wav_quality)
    wav_write(bg_hat, file_name=bg_hat_path, **wav_quality)

    return sdr, sir
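
A small, self-contained sketch (NumPy only, hypothetical sizes) of the reshaping and trimming bookkeeping used above: the (sequences, frames, window_size) network output is flattened in place, and the first context_length * hop time-domain samples, for which no frames were predicted, are dropped.

import numpy as np

# Hypothetical sizes, chosen only to illustrate the bookkeeping.
n_sequences, n_frames, window_size = 3, 20, 8
context_length, hop = 2, 4

# Stand-in for the network output: (sequences, frames, window_size).
voice_predicted = np.zeros((n_sequences, n_frames, window_size), dtype=np.float32)

# Same in-place flattening as in the function above:
# (sequences, frames, window) -> (sequences * frames, window).
voice_predicted.shape = (voice_predicted.shape[0] * voice_predicted.shape[1],
                         window_size)
print(voice_predicted.shape)  # (60, 8)

# The first `context_length` frames are used only as context, so the
# corresponding `context_length * hop` time-domain samples are discarded.
mix = np.zeros(1000, dtype=np.float32)
mix = mix[context_length * hop:]
print(len(mix))  # 992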
Example #2
def test_eval(nnet, B, T, N, L, wsz, hop):
    """
        Method to test the model on the test data. Writes the outcomes in ".wav" format and
        stores them under the defined results path. Optionally, it performs BSS-Eval using
        the mir_eval Python toolbox (used only for comparison with the BSSEval Matlab
        implementation). The evaluation results are stored under the defined save path.
        Args:
            nnet             : (List)      A list containing the PyTorch modules of the skip-filtering model.
            B                : (int)       Batch size.
            T                : (int)       Length of the time-sequence.
            N                : (int)       The FFT size.
            L                : (int)       Number of context frames from the time-sequence.
            wsz              : (int)       Window size in samples.
            hop              : (int)       Hop size in samples.
    """
    nnet[0].eval()
    nnet[1].eval()
    nnet[2].eval()
    nnet[3].eval()

    def my_res(mx, vx, L, wsz):
        """
            A helper function to reshape data according
            to the context frame.
        """
        mx = np.ascontiguousarray(mx[:, L:-L, :], dtype=np.float32)
        mx.shape = (mx.shape[0] * mx.shape[1], wsz)
        vx = np.ascontiguousarray(vx[:, L:-L, :], dtype=np.float32)
        vx.shape = (vx.shape[0] * vx.shape[1], wsz)

        return mx, vx

    # Paths for loading and storing the test-set
    # Generate full paths for test
    test_sources_list = sorted(os.listdir(sources_path + foldersList[1]))
    test_sources_list = [
        sources_path + foldersList[1] + '/' + i for i in test_sources_list
    ]

    # Initializing the containers of the metrics
    sdr = []
    sir = []
    sar = []

    for indx in range(len(test_sources_list)):
        print('Reading:' + test_sources_list[indx])
        # Reading
        bass, _ = Io.wavRead(os.path.join(test_sources_list[indx],
                                          keywords[0]),
                             mono=False)
        drums, _ = Io.wavRead(os.path.join(test_sources_list[indx],
                                           keywords[1]),
                              mono=False)
        oth, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[2]),
                            mono=False)
        vox, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[3]),
                            mono=False)

        bk_true = np.sum(bass + drums + oth, axis=-1) * 0.5
        mix = np.sum(bass + drums + oth + vox, axis=-1) * 0.5
        sv_true = np.sum(vox, axis=-1) * 0.5

        # STFT Analysing
        mx, px = tf.TimeFrequencyDecomposition.STFT(mix, tf.hamming(wsz, True),
                                                    N, hop)

        # Data reshaping (magnitude and phase)
        mx, px, _ = prepare_overlap_sequences(mx, px, px, T, 2 * L, B)

        # The actual "denoising" part
        vx_hat = np.zeros((mx.shape[0], T - L * 2, wsz), dtype=np.float32)
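        # Batched inference: encode each magnitude patch, iterate the recurrent
        # module to refine the latent representation, apply the skip-filtering
        # mask against the mixture, and pass the masked estimate through the
        # final module.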

        for batch in range(mx.shape[0] // B):
            H_enc = nnet[0](mx[batch * B:(batch + 1) * B, :, :])

            H_j_dec = it_infer.iterative_recurrent_inference(nnet[1],
                                                             H_enc,
                                                             criterion=None,
                                                             tol=1e-3,
                                                             max_iter=10)

            vs_hat, mask = nnet[2](H_j_dec,
                                   mx[batch * B:(batch + 1) * B, :, :])
            y_out = nnet[3](vs_hat)
            vx_hat[batch * B:(batch + 1) * B, :, :] = y_out.data.cpu().numpy()

        # Final reshaping
        vx_hat.shape = (vx_hat.shape[0] * vx_hat.shape[1], wsz)
        mx, px = my_res(mx, px, L, wsz)

        # Time-domain recovery
        # Iterative G-L algorithm
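        # Ten Griffin-Lim style passes: re-synthesise the voice with the current
        # phase estimate, re-analyse the result, and keep only the updated phase.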
        for GLiter in range(10):
            sv_hat = tf.TimeFrequencyDecomposition.iSTFT(
                vx_hat, px, wsz, hop, True)
            _, px = tf.TimeFrequencyDecomposition.STFT(sv_hat,
                                                       tf.hamming(wsz, True),
                                                       N, hop)

        # Remove the leading samples for which no estimate exists
        mix = mix[L * hop:]
        sv_true = sv_true[L * hop:]
        bk_true = bk_true[L * hop:]

        # Background music estimation
        if len(sv_true) > len(sv_hat):
            bk_hat = mix[:len(sv_hat)] - sv_hat
        else:
            bk_hat = mix - sv_hat[:len(mix)]

        # Disk writing for external BSS_eval using DSD100-tools (used in our paper)
        Io.wavWrite(
            sv_true, 44100, 16,
            os.path.join(save_path, 'tf_true_sv_' + str(indx) + '.wav'))
        Io.wavWrite(
            bk_true, 44100, 16,
            os.path.join(save_path, 'tf_true_bk_' + str(indx) + '.wav'))
        Io.wavWrite(sv_hat, 44100, 16,
                    os.path.join(save_path, 'tf_hat_sv_' + str(indx) + '.wav'))
        Io.wavWrite(bk_hat, 44100, 16,
                    os.path.join(save_path, 'tf_hat_bk_' + str(indx) + '.wav'))
        Io.wavWrite(mix, 44100, 16,
                    os.path.join(save_path, 'tf_mix_' + str(indx) + '.wav'))

        # Internal BSS-Eval using mir_eval (just for comparison)
        if len(sv_true) > len(sv_hat):
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true[:len(sv_hat)], bk_true[:len(sv_hat)]],
                [sv_hat, bk_hat])
        else:
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true, bk_true],
                [sv_hat[:len(sv_true)], bk_hat[:len(sv_true)]])

        sdr.append(c_sdr)
        sir.append(c_sir)
        sar.append(c_sar)

        # Storing the results iteratively
        with open(os.path.join(save_path, 'SDR.p'), 'wb') as f:
            pickle.dump(sdr, f)
        with open(os.path.join(save_path, 'SIR.p'), 'wb') as f:
            pickle.dump(sir, f)
        with open(os.path.join(save_path, 'SAR.p'), 'wb') as f:
            pickle.dump(sar, f)

    return None
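
A hedged usage sketch of the call above; the hyper-parameter values are illustrative placeholders rather than the settings used in the original experiments, and `encoder`, `decoder`, `skip_filtering`, `denoiser` stand for the four trained PyTorch modules, assumed to be built and loaded elsewhere.

nnet = [encoder, decoder, skip_filtering, denoiser]  # placeholder module names

test_eval(nnet,
          B=16,      # batch size
          T=60,      # time-sequence length in frames
          N=4096,    # FFT size
          L=10,      # number of context frames
          wsz=2049,  # analysis window size in samples
          hop=384)   # hop size in samples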