import os
import pickle

import librosa
import numpy as np
import tensorflow as tf
from mir_eval.separation import bss_eval_sources, validate

import Input
import Models.Unet
import Utils


def bss_evaluate(model_config, dataset, load_model):
    '''
    Computes source separation evaluation metrics of a given separator model on the test set using BSS-Eval.
    :param model_config: Separation network configuration required to build the symbolic computation graph of the network
    :param dataset: Test dataset
    :param load_model: Path to the separator model checkpoint containing the network weights
    :return: List of per-song dicts containing evaluation metrics; wav files of the predicted drums and acc are written to the results directory
    '''
    # Determine input and output shapes, assuming a U-Net separator
    track_number = 1
    freq_bins = model_config["num_fft"] // 2 + 1
    # Drop one bin so the network sees an even number of frequency bins
    disc_input_shape = [1, freq_bins - 1, model_config["num_frames"], 1]  # Shape of discriminator input
    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Placeholders and input normalisation
    input_ph, queue, [mix_context, acc, drums] = Input.get_multitrack_input(
        sep_output_shape[1:], 1, name="input_batch", input_shape=sep_input_shape[1:])
    mix = Input.crop(mix_context, sep_output_shape)
    mix_norm, mix_context_norm, acc_norm, drums_norm = \
        Input.norm(mix), Input.norm(mix_context), Input.norm(acc), Input.norm(drums)

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm, reuse=False)
    separator_acc, separator_drums = Input.denorm(separator_acc_norm), Input.denorm(separator_drums_norm)

    # Start session
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Load the pretrained separator model for testing
    restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
    print("Num of variables: " + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for testing')

    # Collect per-song score dicts
    song_scores = list()
    for multitrack in dataset[8:15]:  # Evaluate only this fixed slice of the test set
        filename = multitrack[0].path
        print("Evaluating file: " + filename + ", track number: " + str(track_number))
        if "DSD100" in filename:
            db = "DSD100"
        elif "Kala" in filename:
            db = "IKala"
        elif "ccmixter" in filename:
            db = "CCMixter"
        elif "MedleyDB" in filename:
            db = "MedleyDB"
        else:
            db = 'musdb18orNIStems'
        song_info = {"Title": filename, "Database": db}

        # Load the mixture and pad it so that the output sources have the same length after STFT/ISTFT
        mix_audio, mix_sr = librosa.load(multitrack[0].path, sr=model_config["expected_sr"])
        mix_length = len(mix_audio)
        # Pad the input so that the ISTFT later yields audio of the same length
        mix_audio_pad = librosa.util.fix_length(mix_audio, mix_length + model_config["num_fft"] // 2)
        mix_mag, mix_ph = Input.audioFileToSpectrogram(mix_audio_pad, model_config["num_fft"], model_config["num_hop"])

        source_time_frames = mix_mag.shape[1]
        # Preallocate source predictions (same shape as the input mixture)
        acc_pred_mag = np.zeros(mix_mag.shape, np.float32)
        drums_pred_mag = np.zeros(mix_mag.shape, np.float32)

        input_time_frames = sep_input_shape[2]
        output_time_frames = sep_output_shape[2]

        # Pad the mixture spectrogram across time at the beginning and end so the
        # network can also make predictions at the edges of the signal
        pad_time_frames = (input_time_frames - output_time_frames) // 2
        mix_mag = np.pad(mix_mag, [(0, 0), (pad_time_frames, pad_time_frames)],
                         mode="constant", constant_values=0.0)
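
        # Illustrative example of the padding arithmetic (frame counts are
        # hypothetical; the real values depend on model_config and the U-Net
        # padding): a network that reads 352 frames of context but emits 256
        # frames needs (352 - 256) // 2 = 48 zero frames on each side, so the
        # sliding window below can also predict the very start and end of the song.
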
        # Iterate over the mixture magnitudes, fetching the network prediction for each excerpt
        for source_pos in range(0, source_time_frames, output_time_frames):
            # If the output patch would reach past the end of the source spectrogram,
            # shift it back so that it predicts exactly up to the end, then stop
            if source_pos + output_time_frames > source_time_frames:
                source_pos = source_time_frames - output_time_frames

            # Prepare the mixture excerpt by selecting the time interval
            mix_mag_part = mix_mag[:, source_pos:source_pos + input_time_frames]
            mix_mag_part = Utils.pad_freqs(mix_mag_part, sep_input_shape[1:3])  # Pad along frequency axis
            mix_mag_part = mix_mag_part[np.newaxis, :, :, np.newaxis]

            # Fetch the network prediction
            acc_mag_part, drums_mag_part = sess.run(
                [separator_acc, separator_drums],
                feed_dict={mix_context: mix_mag_part})

            # Save the predictions, discarding the extra frequency bin to match the mixture spectrogram
            acc_pred_mag[:, source_pos:source_pos + output_time_frames] = acc_mag_part[0, :-1, :, 0]
            drums_pred_mag[:, source_pos:source_pos + output_time_frames] = drums_mag_part[0, :-1, :, 0]

        # Spectrograms to audio, using the mixture phase
        acc_pred_audio = Input.spectrogramToAudioFile(acc_pred_mag, model_config["num_fft"], model_config["num_hop"],
                                                      phase=mix_ph, length=mix_length, phaseIterations=0)
        drums_pred_audio = Input.spectrogramToAudioFile(drums_pred_mag, model_config["num_fft"], model_config["num_hop"],
                                                        phase=mix_ph, length=mix_length, phaseIterations=0)

        # Load the original sources; a float entry marks a missing source, which is treated as silence
        if isinstance(multitrack[1], float):
            acc_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            acc_audio, _ = librosa.load(multitrack[1].path, sr=model_config["expected_sr"])
        if isinstance(multitrack[2], float):
            drum_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            drum_audio, _ = librosa.load(multitrack[2].path, sr=model_config["expected_sr"])

        # If any reference source is completely silent, inject some very slight noise
        # to avoid problems during the SDR computation
        reference_zero = False
        if np.max(np.abs(acc_audio)) == 0.0:
            acc_audio += np.random.uniform(-1e-10, 1e-10, size=acc_audio.shape)
            reference_zero = True
        if np.max(np.abs(drum_audio)) == 0.0:
            drum_audio += np.random.uniform(-1e-10, 1e-10, size=drum_audio.shape)
            reference_zero = True

        # Evaluate BSS metrics according to the MIREX separation methodology:
        # http://www.music-ir.org/mirex/wiki/2016:Singing_Voice_Separation
        ref_sources = np.vstack([acc_audio, drum_audio])
        pred_sources = np.vstack([acc_pred_audio, drums_pred_audio])
        validate(ref_sources, pred_sources)
        scores = bss_eval_sources(ref_sources, pred_sources, compute_permutation=False)
        song_info["SDR"] = scores[0]
        song_info["SIR"] = scores[1]
        song_info["SAR"] = scores[2]

        # Compute mixture-normalised scores and SNR only if both sources are
        # non-silent, since these metrics are undefined otherwise
        if not reference_zero:
            mix_ref = np.vstack([mix_audio, mix_audio])
            mix_scores = bss_eval_sources(ref_sources, mix_ref, compute_permutation=False)
            norm_scores = np.array(scores) - np.array(mix_scores)

            # Compute SNR: 10 * log_10(||s_target||^2 / ||s_target - alpha * s_estimate||^2),
            # scaling the estimate by alpha for optimal SNR
            drums_snr = alpha_snr(drum_audio, drums_pred_audio)
            acc_snr = alpha_snr(acc_audio, acc_pred_audio)
            drums_ref_snr = alpha_snr(drum_audio, mix_audio)
            acc_ref_snr = alpha_snr(acc_audio, mix_audio)

            song_info["NSDR"] = norm_scores[0]
            song_info["NSIR"] = norm_scores[1]
            song_info["SNR"] = np.array([acc_snr, drums_snr])
            song_info["NSNR"] = np.array([acc_snr - acc_ref_snr, drums_snr - drums_ref_snr])
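
        # Informational note: mir_eval's bss_eval_sources returns
        # (SDR, SIR, SAR, perm) arrays with one entry per source; with
        # compute_permutation=False, estimate row i is scored against reference
        # row i, so index 0 is the accompaniment and index 1 the drums here.
        # NSDR/NSIR above follow the MIREX convention of subtracting the scores
        # obtained when the raw mixture itself is used as the estimate.
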
        song_scores.append(song_info)
        print(song_info)

        # Write the predicted sources to the results directory
        dirpath = os.path.join(os.getcwd(), 'results')
        try:
            fn = os.path.join(dirpath, "tr_" + str(track_number) + "_drums.wav")
            librosa.output.write_wav(fn, drums_pred_audio, sr=model_config["expected_sr"])
            fn = os.path.join(dirpath, "tr_" + str(track_number) + "_acc.wav")
            librosa.output.write_wav(fn, acc_pred_audio, sr=model_config["expected_sr"])
        except Exception as e:
            print("Failed to write wav files, error: " + str(e))
        track_number = track_number + 1

    # Persist all per-song scores (experiment_id is expected to be defined at module level)
    try:
        with open(os.path.join(dirpath, str(experiment_id) + "BSS_eval.pkl"), "wb") as file:
            pickle.dump(song_scores, file)
    except Exception as e:
        print("Failed to write score files, error: " + str(e))

    # Close the session and clear the computational graph
    sess.close()
    tf.reset_default_graph()
    return song_scores
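

# The alpha_snr helper used above is assumed to be defined elsewhere in this
# module. A minimal sketch consistent with the formula quoted in the comments,
# SNR = 10 * log_10(||s||^2 / ||s - alpha * e||^2), with alpha chosen as the
# least-squares scaling of the estimate onto the target (hypothetical
# implementation, not the original):
def alpha_snr(target, estimate):
    # alpha minimising ||target - alpha * estimate||^2
    alpha = np.dot(target, estimate) / np.dot(estimate, estimate)
    return 10.0 * np.log10(np.sum(target ** 2) /
                           np.sum((target - alpha * estimate) ** 2))


# Hypothetical usage sketch (the config keys match those read by bss_evaluate;
# the values, dataset, and checkpoint path are illustrative only):
#
# model_config = {"num_fft": 1024, "num_hop": 512, "num_frames": 64,
#                 "num_layers": 4, "expected_sr": 22050}
# song_scores = bss_evaluate(model_config, test_dataset,
#                            "checkpoints/sep_model-100000")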