Example #1
import os
import pickle

import librosa
import numpy as np
import tensorflow as tf
from mir_eval.separation import bss_eval_sources, validate

import Input
import Models.Unet
import Utils

# Note: alpha_snr and experiment_id are assumed to be provided elsewhere in this
# project (e.g. by its training/utility modules); both are referenced below.
def bss_evaluate(model_config, dataset, load_model):
    '''
    Calculates source separation evaluation metrics of a given separator model on the test set using BSS-Eval
    :param model_config: Separation network configuration required to build symbolic computation graph of network
    :param dataset: Test dataset
    :param load_model: Path to separator model checkpoint containing the network weights
    :return: List of per-song dicts with evaluation metrics; wav files of the predicted drums and accompaniment are written to the results directory
    '''
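    # model_config keys used below (values are illustrative only, not from the source):
    #   {"num_fft": 1024, "num_hop": 512, "num_frames": 64, "num_layers": 4, "expected_sr": 22050}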
    # Determine input and output shapes, if we use U-net as separator
    track_number = 1
    freq_bins = model_config["num_fft"] // 2 + 1  # Number of STFT frequency bins
    disc_input_shape = [1, freq_bins - 1, model_config["num_frames"],
                        1]  # Shape of discriminator input (freq_bins - 1 gives an even bin count)

    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Placeholders and input normalisation
    input_ph, queue, [mix_context, acc, drums] = Input.get_multitrack_input(
        sep_output_shape[1:],
        1,
        name="input_batch",
        input_shape=sep_input_shape[1:])

    mix = Input.crop(mix_context, sep_output_shape)
    mix_norm, mix_context_norm, acc_norm, drums_norm = Input.norm(mix), Input.norm(mix_context), \
                                                        Input.norm(acc), Input.norm(drums)

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm,
                                                              reuse=False)
    separator_acc, separator_drums = Input.denorm(
        separator_acc_norm), Input.denorm(separator_drums_norm)

    # Start session and initialize variables
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Load the pretrained separator weights from the checkpoint
    restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
    print("Number of variables: " + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for testing')

    # Initialize the per-song score list and make sure the results directory exists
    song_scores = list()
    dirpath = os.path.join(os.getcwd(), "results")
    os.makedirs(dirpath, exist_ok=True)

    for multitrack in dataset[8:15]:  # Evaluate a fixed subset of the test set
        filename = multitrack[0].path
        print("Evaluating file: " + filename + ", track number: " +
              str(track_number))
        if "DSD100" in filename:
            db = "DSD100"
        elif "Kala" in filename:
            db = "IKala"
        elif "ccmixter" in filename:
            db = "CCMixter"
        elif "MedleyDB" in filename:
            db = "MedleyDB"
        else:
            db = "musdb18orNIStems"
        song_info = {"Title": filename, "Database": db}

        # Load mixture and pad it so that output sources have the same length after STFT/ISTFT
        mix_audio, mix_sr = librosa.load(multitrack[0].path,
                                         sr=model_config["expected_sr"])
        mix_length = len(mix_audio)
        # Pad input so that ISTFT later leads to same-length audio
        mix_audio_pad = librosa.util.fix_length(
            mix_audio, mix_length + model_config["num_fft"] // 2)
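        # E.g. with a hypothetical num_fft of 1024, 512 samples of padding are appended
        # so the final STFT frame covers the signal tail; the reconstruction is trimmed
        # back to mix_length via the length argument of spectrogramToAudioFile below.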
        mix_mag, mix_ph = Input.audioFileToSpectrogram(mix_audio_pad,
                                                       model_config["num_fft"],
                                                       model_config["num_hop"])
        source_time_frames = mix_mag.shape[1]

        # Preallocate source predictions (same shape as input mixture)
        acc_pred_mag = np.zeros(mix_mag.shape, np.float32)
        drums_pred_mag = np.zeros(mix_mag.shape, np.float32)

        input_time_frames = sep_input_shape[2]
        output_time_frames = sep_output_shape[2]

        # Pad the mix spectrogram in time at the beginning and end so the network can predict at the borders of the signal
        pad_time_frames = (input_time_frames - output_time_frames) // 2
        mix_mag = np.pad(mix_mag, [(0, 0), (pad_time_frames, pad_time_frames)],
                         mode="constant",
                         constant_values=0.0)
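        # Worked example with hypothetical sizes: for input_time_frames = 148 and
        # output_time_frames = 132, pad_time_frames = 8, so eight zero frames are added
        # on each side and even the first and last output frames see full input context.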

        # Iterate over mixture magnitudes, fetch network prediction
        for source_pos in range(0, source_time_frames, output_time_frames):
            # If the output patch would reach past the end of the source spectrogram,
            # shift it back so the final prediction ends exactly at the signal end
            if source_pos + output_time_frames > source_time_frames:
                source_pos = source_time_frames - output_time_frames

            # Prepare mixture excerpt by selecting time interval
            mix_mag_part = mix_mag[:,
                                   source_pos:source_pos + input_time_frames]
            mix_mag_part = Utils.pad_freqs(
                mix_mag_part, sep_input_shape[1:3])  # Pad along frequency axis
            mix_mag_part = mix_mag_part[np.newaxis, :, :, np.newaxis]
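            # Shape is now [batch=1, freqs, time, channels=1] as expected by the separator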

            # Fetch network prediction
            acc_mag_part, drums_mag_part = sess.run(
                [separator_acc, separator_drums],
                feed_dict={mix_context: mix_mag_part})

            # Save predictions; the last frequency bin of the network output is discarded
            acc_pred_mag[:, source_pos:source_pos +
                         output_time_frames] = acc_mag_part[0, :-1, :, 0]
            drums_pred_mag[:, source_pos:source_pos +
                           output_time_frames] = drums_mag_part[0, :-1, :, 0]

        # Spectrograms to audio, using mixture phase
        acc_pred_audio = Input.spectrogramToAudioFile(acc_pred_mag,
                                                      model_config["num_fft"],
                                                      model_config["num_hop"],
                                                      phase=mix_ph,
                                                      length=mix_length,
                                                      phaseIterations=0)

        drums_pred_audio = Input.spectrogramToAudioFile(
            drums_pred_mag,
            model_config["num_fft"],
            model_config["num_hop"],
            phase=mix_ph,
            length=mix_length,
            phaseIterations=0)
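        # phase=mix_ph with phaseIterations=0 presumably reuses the mixture phase
        # directly instead of refining it with Griffin-Lim iterations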

        # Load original sources
        if isinstance(multitrack[1], float):
            acc_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            acc_audio, _ = librosa.load(multitrack[1].path,
                                        sr=model_config["expected_sr"])
        if isinstance(multitrack[2], float):
            drum_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            drum_audio, _ = librosa.load(multitrack[2].path,
                                         sr=model_config["expected_sr"])

        # If any reference source is completely silent, inject very slight noise to avoid problems during SDR computation
        reference_zero = False
        if np.max(np.abs(acc_audio)) == 0.0:
            acc_audio += np.random.uniform(-1e-10, 1e-10, size=acc_audio.shape)
            reference_zero = True
        if np.max(np.abs(drum_audio)) == 0.0:
            drum_audio += np.random.uniform(-1e-10,
                                            1e-10,
                                            size=drum_audio.shape)
            reference_zero = True

        # Evaluate BSS metrics following the MIREX singing voice separation protocol:
        # http://www.music-ir.org/mirex/wiki/2016:Singing_Voice_Separation
        ref_sources = np.vstack([acc_audio, drum_audio])
        pred_sources = np.vstack([acc_pred_audio, drums_pred_audio])
        validate(ref_sources, pred_sources)
        scores = bss_eval_sources(ref_sources,
                                  pred_sources,
                                  compute_permutation=False)
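        # bss_eval_sources returns (sdr, sir, sar, perm); each metric is an array with
        # one entry per source (index 0 = accompaniment, index 1 = drums, since
        # compute_permutation=False preserves the stacking order)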

        song_info["SDR"] = scores[0]
        song_info["SIR"] = scores[1]
        song_info["SAR"] = scores[2]

        # Compute reference scores and SNR only if both sources are non-silent, since these metrics are undefined otherwise
        if not reference_zero:
            mix_ref = np.vstack([mix_audio, mix_audio])
            mix_scores = bss_eval_sources(ref_sources,
                                          mix_ref,
                                          compute_permutation=False)
            norm_scores = np.array(scores) - np.array(mix_scores)

            # Compute SNR: 10 * log_10(||s_target||^2 / ||s_target - alpha * s_estimate||^2),
            # with alpha = <s_target, s_estimate> / ||s_estimate||^2 scaling the estimate for optimal SNR
            drums_snr = alpha_snr(drum_audio, drums_pred_audio)
            acc_snr = alpha_snr(acc_audio, acc_pred_audio)
            drums_ref_snr = alpha_snr(drum_audio, mix_audio)
            acc_ref_snr = alpha_snr(acc_audio, mix_audio)

            song_info["NSDR"] = norm_scores[0]
            song_info["NSIR"] = norm_scores[1]
            song_info["SNR"] = np.array([acc_snr, drums_snr])
            song_info["NSNR"] = np.array(
                [acc_snr - acc_ref_snr, drums_snr - drums_ref_snr])

        song_scores.append(song_info)
        print(song_info)
        try:
            fn = os.path.join(dirpath,
                              "tr_" + str(track_number) + "_drums.wav")
            librosa.output.write_wav(fn,
                                     drums_pred_audio,
                                     sr=model_config["expected_sr"])
            fn = os.path.join(dirpath, "tr_" + str(track_number) + "_acc.wav")
            librosa.output.write_wav(fn,
                                     acc_pred_audio,
                                     sr=model_config["expected_sr"])
        except Exception as e:
            print("Failed to write wav files, error: " + str(e))

        track_number = track_number + 1
    try:
        # experiment_id is assumed to be defined at module level (e.g. by the training script)
        with open(os.path.join(dirpath, str(experiment_id) + "BSS_eval.pkl"),
                  "wb") as file:
            pickle.dump(song_scores, file)
    except Exception as e:
        print("Failed to write score files, error: " + str(e))

    # Close session and clear the computational graph
    sess.close()
    tf.reset_default_graph()

    return song_scores
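

# A minimal sketch of the alpha_snr helper referenced above, assuming it implements
# the scaled SNR described in the comment inside bss_evaluate: the estimate is scaled
# by the least-squares factor alpha before the error energy is measured. The real
# helper is defined elsewhere in this project; this version is illustrative only.
def alpha_snr_sketch(target, estimate):
    # Optimal scaling of the estimate: alpha = <target, estimate> / ||estimate||^2
    alpha = np.dot(target, estimate) / (np.dot(estimate, estimate) + 1e-12)
    error = target - alpha * estimate
    # SNR in dB of the target energy against the scaled-estimate error energy
    return 10.0 * np.log10(np.dot(target, target) / (np.dot(error, error) + 1e-12))


# Hypothetical invocation (configuration values and checkpoint path are placeholders):
# model_config = {"num_fft": 1024, "num_hop": 512, "num_frames": 64,
#                 "num_layers": 4, "expected_sr": 22050}
# scores = bss_evaluate(model_config, test_dataset, "checkpoints/model-100000")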