def test_SumAllComponents(self):
    """Composing all sources with a single temporal segment must
    reconstruct the original model input (self.reference)."""
    factorization = SpleeterFactorization(
        self.dp,
        n_temporal_segments=1,
        composition_fn=None,
        model_name='spleeter:5stems',
        spleeter_sources_path=spleeter_sources_path)
    all_components = factorization.compose_model_input()
    # BUG FIX: atol was 10**5 (i.e. 100000), which makes allclose trivially
    # true for audio samples in [-1, 1] and renders the test vacuous.
    # The intended absolute tolerance is 1e-5.
    self.assertTrue(np.allclose(all_components, self.reference, atol=1e-5))
def test_TemporalSegmentation(self):
    """Segmenting into n_segments must still reconstruct the input
    (up to samples dropped at the end) and expose n_segments * 5 components."""
    n_segments = 7
    factorization = SpleeterFactorization(
        self.dp,
        n_temporal_segments=n_segments,
        composition_fn=None,
        model_name='spleeter:5stems',
        spleeter_sources_path=spleeter_sources_path)
    all_components = factorization.compose_model_input()
    leng = len(all_components)  # to deal with ignored samples at the end
    # BUG FIX: atol was 10 ** 5 (100000), trivially satisfied by audio in
    # [-1, 1]; the intended tolerance is 1e-5.
    self.assertTrue(np.allclose(all_components, self.reference[:leng], atol=1e-5))
    self.assertEqual(n_segments * 5, factorization.get_number_components())  # nr. sources = 5
def test_AnalysisWindow(self):
    """Restricting the analysis window must reproduce exactly the
    corresponding slice of the reference signal."""
    start = 35000
    leng = 27333
    reference = self.reference[start:start + leng]
    factorization = SpleeterFactorization(
        self.dp,
        n_temporal_segments=1,
        composition_fn=None,
        model_name='spleeter:5stems',
        spleeter_sources_path=spleeter_sources_path)
    factorization.set_analysis_window(start, leng)
    all_components = factorization.compose_model_input()
    # BUG FIX: atol was 10 ** 5 (100000) — vacuous for audio-range floats;
    # the intended absolute tolerance is 1e-5.
    self.assertTrue(np.allclose(all_components, reference, atol=1e-5))
# Run the tagging model on the prepared input and inspect its predictions.
# NOTE(review): this fragment references names defined elsewhere in the file
# (model, x, tags, config, args, sample, snippet_starts, n_segments,
# num_samples, n_display_components, composition_fn, data_provider,
# create_predict_fn) — their exact semantics cannot be confirmed from here.
outputs = model(x)
# Per-snippet top tag: argmax over the class axis of the (snippet, class) scores.
top_tag_per_snippet = torch.argmax(outputs.detach().cpu(), axis=1)
print("top_tag_per_snippet", len(top_tag_per_snippet), top_tag_per_snippet)
# Global ranking of tags: sort classes by their mean score across snippets.
sorted_args = torch.argsort(outputs.detach().cpu().mean(axis=0), descending=True)
print([tags[t] for t in sorted_args[0:3]])
top_idx = sorted_args[0].item()
# Rank snippets by their score for the globally top tag.
sorted_snippets = torch.argsort(outputs[:, top_idx].detach().cpu()).numpy()
print("top idx:", top_idx)
print("top segments:", sorted_snippets)
predict_fn = create_predict_fn(model, config)
# Source-separation-based factorization (5 spleeter stems x n_segments
# temporal segments) used as interpretable components for LIME.
spleeter_factorization = SpleeterFactorization(data_provider,
                                               n_temporal_segments=n_segments,
                                               composition_fn=composition_fn,
                                               model_name='spleeter:5stems')
explainer = lime_audio.LimeAudioExplainer(verbose=True, absolute_feature_sort=False)
# Explain each snippet, either against the global top tag or the snippet's own
# top tag depending on the CLI flag.
# NOTE(review): the loop body may continue beyond this chunk — view is truncated.
for sn in range(len(snippet_starts)):
    if args.use_global_tag:
        labels = [top_idx]
    else:
        snippet_tag = top_tag_per_snippet[sn].item()
        labels = [snippet_tag]
    print("processing {}_{}".format(sample, sn))
    # Encodes model type, class, snippet tag, batch size, snippet index,
    # segmentation, sample count and display-component count into one name.
    explanation_name = "{}/{}_cls{}_sntag{}_nc{}_sn{}_seg{}_smp{}_nd{}".format(config.model_type, sample, top_idx, labels[0], config.batch_size, sn, n_segments, num_samples, n_display_components)
# Example script: explain a music-tagging model's prediction on one audio file
# using audioLIME with a spleeter-based (source separation) factorization.
from audioLIME.factorization import SpleeterFactorization
from audioLIME import lime_audio
import soundfile as sf
import os
from examples.sota_utils import prepare_config, get_predict_fn

if __name__ == '__main__':
    # Hard-coded local paths; adjust for your environment.
    audio_path = '/share/home/verena/data/samples/3_Hop Along-SisterCities.stem.mp4_sn0_original.wav'
    path_sota = '/home/verena/deployment/sota-music-tagging-models/'
    # "fcn" model config with an input length of 29 s at 16 kHz samples.
    config = prepare_config("fcn", 29 * 16000)
    predict_fn = get_predict_fn(config)
    # NOTE(review): RawAudioProvider is not imported in this view — the file
    # likely needs `from audioLIME.data_provider import RawAudioProvider`;
    # confirm against the full source.
    data_provider = RawAudioProvider(audio_path)
    # 10 temporal segments x 5 spleeter stems = interpretable components.
    spleeter_factorization = SpleeterFactorization(data_provider,
                                                   n_temporal_segments=10,
                                                   composition_fn=None,
                                                   model_name='spleeter:5stems')
    explainer = lime_audio.LimeAudioExplainer(verbose=True, absolute_feature_sort=False)
    explanation = explainer.explain_instance(factorization=spleeter_factorization,
                                             predict_fn=predict_fn,
                                             top_labels=1,
                                             num_samples=16384,
                                             batch_size=32
                                             )
    # Only the single top label was requested above.
    label = list(explanation.local_exp.keys())[0]
    # NOTE(review): call is truncated at the end of this chunk — remaining
    # arguments (and the rest of the script) are outside the visible source.
    top_components, component_indeces = explanation.get_sorted_components(label,
                                                                          positive_components=True,
                                                                          negative_components=False,