def overfit_model(scaper_folder):
    """Fixture: overfit a recurrent chimera model on a single Scaper item.

    Seeds the RNG, restricts the dataset to one item, trains for one epoch,
    and yields the path to the best checkpoint together with the processed
    item the model was trained on.
    """
    nussl.utils.seed(0)
    # Transform chain: PSA targets -> magnitude weights -> tensors -> excerpt.
    pipeline = datasets.transforms.Compose([
        datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        datasets.transforms.MagnitudeWeights(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.GetExcerpt(100)
    ])
    dataset = datasets.Scaper(scaper_folder, transform=pipeline)
    # Keep a single item so one epoch is enough to overfit.
    dataset.items = [dataset.items[5]]
    loader = torch.utils.data.DataLoader(dataset)

    num_features = dataset[0]['mix_magnitude'].shape[1]
    chimera_config = ml.networks.builders.build_recurrent_chimera(
        num_features, 50, 1, True, 0.3, 20, 'sigmoid', 2, 'sigmoid',
        normalization_class='InstanceNorm')
    model = ml.SeparationModel(chimera_config).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)

    # Weighted combination of deep clustering and PIT-L1 losses.
    losses = {
        'DeepClusteringLoss': {'weight': 0.2},
        'PermutationInvariantLoss': {'args': ['L1Loss'], 'weight': 0.8}
    }
    closure = ml.train.closures.TrainClosure(losses, optimizer, model)
    trainer, _ = ml.train.create_train_and_validation_engines(
        closure, device=DEVICE)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Prefer the persistent fix_dir when set (useful for debugging runs).
        output_dir = fix_dir if fix_dir else tmpdir
        ml.train.add_stdout_handler(trainer)
        ml.train.add_validate_and_checkpoint(
            output_dir, model, optimizer, dataset, trainer)
        trainer.run(loader, max_epochs=1, epoch_length=EPOCH_LENGTH)
        best_path = os.path.join(
            trainer.state.output_folder, 'checkpoints', 'best.model.pth')
        yield best_path, dataset.process_item(dataset.items[0])
def overfit_audio_model(scaper_folder):
    """Fixture: overfit a recurrent end-to-end audio model on one Scaper item.

    Mirrors ``overfit_model`` but trains directly on waveforms with a
    PIT SI-SDR loss; yields the best-checkpoint path and the processed item.
    """
    nussl.utils.seed(0)
    pipeline = datasets.transforms.Compose([
        datasets.transforms.GetAudio(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.GetExcerpt(
            32000, time_dim=1, tf_keys=['mix_audio', 'source_audio'])
    ])
    dataset = datasets.Scaper(scaper_folder, transform=pipeline)
    # Restrict to one item so the network can memorize it in one epoch.
    dataset.items = [dataset.items[5]]
    loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    e2e_config = ml.networks.builders.build_recurrent_end_to_end(
        256, 256, 64, 'sqrt_hann', 50, 2, True, 0.3, 2, 'sigmoid',
        num_audio_channels=1, mask_complex=False, rnn_type='lstm',
        mix_key='mix_audio')
    model = ml.SeparationModel(e2e_config).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # PIT over SI-SDR on raw audio; 'keys' maps batch entries to loss args.
    losses = {
        'PermutationInvariantLoss': {
            'args': ['SISDRLoss'],
            'weight': 1.0,
            'keys': {'audio': 'estimates', 'source_audio': 'targets'}
        }
    }
    closure = ml.train.closures.TrainClosure(losses, optimizer, model)
    trainer, _ = ml.train.create_train_and_validation_engines(
        closure, device=DEVICE)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Persist under fix_dir/'dae' when fix_dir is set; else use the tempdir.
        output_dir = os.path.join(fix_dir, 'dae') if fix_dir else tmpdir
        ml.train.add_stdout_handler(trainer)
        ml.train.add_validate_and_checkpoint(
            output_dir, model, optimizer, dataset, trainer)
        ml.train.add_progress_bar_handler(trainer)
        trainer.run(loader, max_epochs=1, epoch_length=EPOCH_LENGTH)
        best_path = os.path.join(
            trainer.state.output_folder, 'checkpoints', 'best.model.pth')
        yield best_path, dataset.process_item(dataset.items[0])
def test_clustering_separation_base(scaper_folder, monkeypatch):
    """Exercise ClusteringSeparationBase across all allowed clustering types.

    Checks constructor/run validation, confidence behavior, and that ideal
    binary features produce high BSSEval scores with both soft and binary
    masks.
    """
    dataset = datasets.Scaper(scaper_folder)
    item = dataset[5]
    mix, sources = item['mix'], item['sources']

    # An unknown clustering type must be rejected at construction time.
    pytest.raises(
        SeparationException, separation.ClusteringSeparationBase,
        mix, 2, clustering_type='not allowed')

    clustering_types = (
        separation.base.clustering_separation_base.ALLOWED_CLUSTERING_TYPES)

    # Features of the wrong shape must be rejected by run().
    separator = separation.ClusteringSeparationBase(mix, 2)
    bad_features = np.ones(100)
    pytest.raises(SeparationException, separator.run, bad_features)

    # Build ideal binary features from the ground-truth source magnitudes.
    good_features = np.stack(
        [np.abs(s.stft()) for _, s in sources.items()], axis=-1)
    good_features = (
        good_features == good_features.max(axis=-1, keepdims=True))

    def dummy_extract(self):
        return good_features

    monkeypatch.setattr(
        separation.ClusteringSeparationBase, 'extract_features',
        dummy_extract)

    def _assert_scores_above(estimates, threshold):
        # Every source should beat the threshold on scale-invariant metrics.
        evaluator = evaluation.BSSEvalScale(
            list(sources.values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()
        for key in evaluator.source_labels:
            for metric in ['SI-SDR', 'SI-SIR']:
                for val in scores[key][metric]:
                    assert val > threshold

    for clustering_type in clustering_types:
        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type)
        # Confidence is undefined until the separator has been run.
        pytest.raises(SeparationException, separator.confidence)
        estimates = separator()
        assert separator.confidence() == 1.0
        _assert_scores_above(estimates, 5)

        # Binary masking should score even higher on ideal features.
        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type, mask_type='binary')
        _assert_scores_above(separator(), 9)
def mix_and_sources(scaper_folder):
    """Return the mixture and ground-truth sources of the first dataset item."""
    first_item = datasets.Scaper(scaper_folder)[0]
    return first_item['mix'], first_item['sources']