def test_separation_base_interact(mix_source_folder, monkeypatch):
    """Smoke-test ``SeparationBase.interact`` with gradio stubbed out.

    ``run`` and ``make_audio_signals`` are replaced by a stub that returns the
    input signal, and ``gradio.Interface`` by a no-op class. That lets
    ``interact`` be called (with and without ``add_residual``) without
    launching a web UI.
    """
    mix = datasets.MixSourceFolder(mix_source_folder)[0]['mix']

    def passthrough(self):
        # Stand-in for a real separator: hand back the input signal unchanged.
        return self.audio_signal

    class FakeInterface:
        # Accepts any constructor arguments; launch() does nothing.
        def __init__(self, *args, **kwargs):
            pass

        def launch(self, *args, **kwargs):
            pass

    import gradio

    for attr in ('make_audio_signals', 'run'):
        monkeypatch.setattr(separation.SeparationBase, attr, passthrough)
    monkeypatch.setattr(gradio, 'Interface', FakeInterface)

    separator = separation.SeparationBase(mix)
    separator.interact()
    separator.interact(add_residual=True)
def test_cache_dataset(mix_source_folder):
    """Items read back from a populated ``Cache`` must match the uncached items.

    The dataset is first evaluated without a cache. A ``Cache`` transform is
    then appended to the same pipeline and populated via
    ``ml.train.cache_dataset``. Finally every item is re-read with
    ``cache_populated=True`` and compared key by key.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tfms = datasets.transforms.Compose([
            datasets.transforms.PhaseSensitiveSpectrumApproximation(),
            datasets.transforms.ToSeparationModel(),
        ])
        cache_tfm = datasets.transforms.Cache(
            os.path.join(tmpdir, 'cache'), overwrite=True)

        # Reference pass: transforms applied directly, no cache involved.
        uncached = datasets.MixSourceFolder(mix_source_folder, transform=tfms)
        expected = [uncached[idx] for idx in range(len(uncached))]

        # Append the cache to the shared pipeline and fill it.
        tfms.transforms.append(cache_tfm)
        dataset = datasets.MixSourceFolder(
            mix_source_folder, transform=tfms, cache_populated=False)
        assert tfms.transforms[-1].cache.nchunks_initialized == 0
        ml.train.cache_dataset(dataset)
        assert tfms.transforms[-1].cache.nchunks_initialized == len(dataset)

        # Read everything back from the cache.
        dataset.cache_populated = True
        actual = [dataset[idx] for idx in range(len(dataset))]

        for ref, got in zip(expected, actual):
            for key, ref_val in ref.items():
                got_val = got[key]
                if torch.is_tensor(ref_val):
                    assert torch.allclose(ref_val, got_val)
                else:
                    assert ref_val == got_val
def test_separation_base(mix_source_folder, monkeypatch):
    """Check construction, abstract hooks, repr, equality and re-calling.

    Covers the main behaviors of ``separation.SeparationBase``.
    """
    item = datasets.MixSourceFolder(mix_source_folder)[0]
    mix = item['mix']
    sources = item['sources']

    # An empty signal is expected to warn; a missing signal to be rejected.
    empty_signal = AudioSignal()
    with pytest.warns(UserWarning):
        separation.SeparationBase(empty_signal)
    with pytest.raises(ValueError):
        separation.SeparationBase(None)

    separator = separation.SeparationBase(mix)
    assert separator.sample_rate == mix.sample_rate
    assert separator.stft_params == mix.stft_params

    # The base class leaves every hook (and calling the object) unimplemented.
    for hook in (separator.run, separator.make_audio_signals,
                 separator.get_metadata, separator):
        with pytest.raises(NotImplementedError):
            hook()

    def noop(self):
        pass

    # With only ``run`` stubbed, calling the separator is still expected to fail.
    monkeypatch.setattr(separation.SeparationBase, 'run', noop)
    with pytest.raises(NotImplementedError):
        separator()

    assert separator.__class__.__name__ in str(separator)
    assert str(mix) in str(separator)

    # Equality is expected to reflect both extra attributes and the signal.
    other = separation.SeparationBase(mix)
    separator.fake_array = np.zeros(100)
    other.fake_array = np.zeros(100)
    assert separator == other
    other.fake_array = np.ones(100)
    assert separator != other

    diff_other = separation.SeparationBase(sources['s1'])
    diff_other.fake_array = np.zeros(100)
    assert separator != diff_other

    # Calling with ``audio_signal=`` should swap in the new signal.
    separator = separation.SeparationBase(mix)
    assert separator.audio_signal == mix
    monkeypatch.setattr(separation.SeparationBase, 'make_audio_signals', noop)
    separator(audio_signal=sources['s1'])
    assert separator.audio_signal == sources['s1']
def test_cache_dataset_with_dataloader(mix_source_folder):
    """Populate a ``Cache`` by handing a DataLoader to ``ml.train.cache_dataset``.

    Afterwards the cache should report one initialized chunk per dataset item.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cache_location = os.path.join(tmpdir, 'cache')
        pipeline = [
            datasets.transforms.PhaseSensitiveSpectrumApproximation(),
            datasets.transforms.ToSeparationModel(),
            datasets.transforms.Cache(cache_location, overwrite=True),
            datasets.transforms.GetExcerpt(400),
        ]
        tfms = datasets.transforms.Compose(pipeline)
        dataset = datasets.MixSourceFolder(
            mix_source_folder, transform=tfms, cache_populated=False)
        loader = torch.utils.data.DataLoader(
            dataset, shuffle=True, batch_size=2)

        ml.train.cache_dataset(loader)

        # The Cache transform sits second-to-last, just before GetExcerpt.
        cache = tfms.transforms[-2].cache
        assert cache.nchunks_initialized == len(dataset)
def test_mask_separation_base(mix_source_folder, monkeypatch):
    """Validate ``MaskSeparationBase`` options, helper masks and estimates.

    Checks mask-type and threshold validation, the ``ones_mask`` and
    ``zeros_mask`` helpers, and ``make_audio_signals``.
    """
    mix = datasets.MixSourceFolder(mix_source_folder)[0]['mix']

    class DummyMask(core.masks.MaskBase):
        # A MaskBase subclass that is neither SoftMask nor BinaryMask and
        # accepts any data; it should not be usable as a mask_type.
        @staticmethod
        def _validate_mask(mask_):
            pass

    # Defaults.
    separator = separation.MaskSeparationBase(mix)
    assert separator.mask_type == core.masks.SoftMask
    assert separator.mask_threshold == 0.5

    # mask_type may be given as a mask instance or as a string.
    accepted = [
        (core.masks.SoftMask(mask_shape=(100, 10)), core.masks.SoftMask),
        ('binary', core.masks.BinaryMask),
        (core.masks.BinaryMask(mask_shape=(100, 10)), core.masks.BinaryMask),
    ]
    for mask_type, expected_type in accepted:
        separator = separation.MaskSeparationBase(mix, mask_type=mask_type)
        assert separator.mask_type == expected_type

    for bad_type in (None, 'invalid', DummyMask(mask_shape=(100, 10))):
        with pytest.raises(ValueError):
            separation.MaskSeparationBase(mix, mask_type=bad_type)

    separator = separation.MaskSeparationBase(mix, mask_threshold=0.2)
    assert separator.mask_threshold == 0.2
    for bad_threshold in (1.5, 'not a float'):
        with pytest.raises(ValueError):
            separation.MaskSeparationBase(mix, mask_threshold=bad_threshold)

    # An all-ones mask (soft, then binary) should reconstruct the mixture.
    for extra in ({}, {'mask_type': 'binary'}):
        separator = separation.MaskSeparationBase(mix, **extra)
        ones_mask = separator.ones_mask(mix.stft().shape)
        masked = mix.apply_mask(ones_mask)
        masked.istft()
        assert np.allclose(masked.audio_data, mix.audio_data, atol=1e-6)

    # An all-zeros soft mask should yield silence.
    separator = separation.MaskSeparationBase(mix)
    zeros_mask = separator.zeros_mask(mix.stft().shape)
    masked_zeros = mix.apply_mask(zeros_mask)
    masked_zeros.istft()
    assert np.allclose(
        masked_zeros.audio_data,
        np.zeros(masked_zeros.audio_data.shape),
        atol=1e-6,
    )

    separator = separation.MaskSeparationBase(mix, mask_type='binary')
    ones_mask = separator.ones_mask(mix.stft().shape)
    zeros_mask = separator.zeros_mask(mix.stft().shape)
    masked_ones = mix.apply_mask(ones_mask)
    masked_ones.istft()
    assert np.allclose(masked_ones.audio_data, mix.audio_data, atol=1e-6)

    # No result masks have been set yet, so building estimates should fail.
    with pytest.raises(SeparationException):
        separator.make_audio_signals()

    separator = separation.MaskSeparationBase(mix, mask_type='binary')
    separator.result_masks = [ones_mask, zeros_mask]
    estimates = separator.make_audio_signals()
    for estimate, reference in zip(estimates, [masked_ones, masked_zeros]):
        assert estimate == reference

    # Binary result masks on a soft-mask separator are expected to be rejected.
    separator = separation.MaskSeparationBase(mix, mask_type='soft')
    separator.result_masks = [ones_mask, zeros_mask]
    with pytest.raises(SeparationException):
        separator.make_audio_signals()
def test_overfit_a(mix_source_folder):
    """Train a small mask-inference model for a few epochs on a cached dataset.

    Runs the training and validation engines with stdout, checkpoint and
    tensorboard handlers. It then reloads ``best.model.pth`` and plots each
    per-epoch history series stored in the checkpoint metadata.

    Outputs (checkpoints, plots) go to ``fix_dir`` when it is set, so they can
    be inspected, and to a temporary directory otherwise. The feature cache is
    scratch data and always lives in the temporary directory.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        _dir = fix_dir if fix_dir else tmpdir

        # Keep the cache inside the test's temporary directory rather than a
        # fixed path under the home directory, so runs don't leak state.
        tfms = datasets.transforms.Compose([
            datasets.transforms.PhaseSensitiveSpectrumApproximation(),
            datasets.transforms.ToSeparationModel(),
            datasets.transforms.Cache(
                os.path.join(tmpdir, 'cache'), overwrite=True),
            datasets.transforms.GetExcerpt(400)
        ])
        dataset = datasets.MixSourceFolder(mix_source_folder, transform=tfms)
        ml.train.cache_dataset(dataset)
        dataset.cache_populated = True

        dataloader = torch.utils.data.DataLoader(
            dataset, shuffle=True, batch_size=len(dataset), num_workers=2)

        # Create the model based on the first item in the dataset; the second
        # dimension of the magnitude is the number of features.
        n_features = dataset[0]['mix_magnitude'].shape[1]
        mi_config = ml.networks.builders.build_recurrent_mask_inference(
            n_features, 50, 1, False, 0.0, 2, 'sigmoid',
        )
        model = ml.SeparationModel(mi_config)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Fewer iterations per epoch on CPU to keep the test fast.
        epoch_length = 100 if device == 'cuda' else 10
        model = model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        loss_dictionary = {'L1Loss': {'weight': 1.0}}

        train_closure = ml.train.closures.TrainClosure(
            loss_dictionary, optimizer, model)
        val_closure = ml.train.closures.ValidationClosure(
            loss_dictionary, model)

        trainer, validator = ml.train.create_train_and_validation_engines(
            train_closure, val_closure, device=device)

        ml.train.add_stdout_handler(trainer, validator)
        ml.train.add_validate_and_checkpoint(
            _dir, model, optimizer, dataset, trainer,
            val_data=dataloader, validator=validator)
        ml.train.add_tensorboard_handler(_dir, trainer)

        trainer.run(dataloader, max_epochs=5, epoch_length=epoch_length)

        model_path = os.path.join(
            trainer.state.output_folder, 'checkpoints', 'best.model.pth')
        state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict['state_dict'])

        # Create the plot directory where the files are actually written.
        plot_dir = os.path.join(trainer.state.output_folder, 'plots')
        os.makedirs(plot_dir, exist_ok=True)

        history = state_dict['metadata']['trainer.state.epoch_history']
        for key in history:
            fig = plt.figure(figsize=(10, 4))
            plt.title(f"epoch:{key}")
            plt.plot(np.array(history[key]).reshape(-1, ))
            plt.savefig(
                os.path.join(plot_dir, f"epoch:{key.replace('/', ':')}.png"))
            # Close each figure so they don't accumulate across keys and tests.
            plt.close(fig)
def test_create_engine(mix_source_folder):
    """Create, run, checkpoint and resume training/validation engines.

    A dummy step whose loss is ``-engine.state.iteration`` stands in for a
    real model update. The test checks:

    * the expected checkpoint files (latest/best/per-epoch) are written;
    * the history lengths are correct;
    * the ``Cache`` transform is stripped from the saved dataset metadata;
    * a fresh trainer can be restored from ``trainer.state_dict``.

    Outputs go to ``fix_dir`` when it is set and to a temporary directory
    otherwise. The cache is placed under the same directory, so the test no
    longer requires ``fix_dir`` to be set.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        _dir = fix_dir if fix_dir else tmpdir

        # Load the dataset with transforms. The cache path is resolved from
        # _dir; it is identical to os.path.join(fix_dir, 'cache') whenever
        # fix_dir is set.
        tfms = datasets.transforms.Compose([
            datasets.transforms.PhaseSensitiveSpectrumApproximation(),
            datasets.transforms.ToSeparationModel(),
            datasets.transforms.Cache(os.path.join(_dir, 'cache'))
        ])
        dataset = datasets.MixSourceFolder(mix_source_folder, transform=tfms)

        # Create the model based on the first item in the dataset; the second
        # dimension of the magnitude is the number of features.
        n_features = dataset[0]['mix_magnitude'].shape[1]
        mi_config = ml.networks.builders.build_recurrent_mask_inference(
            n_features, 50, 2, True, 0.3, 2, 'softmax',
        )
        model = ml.SeparationModel(mi_config)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        def train_batch(engine, data):
            # Dummy step: the loss decreases with every iteration.
            loss = -engine.state.iteration
            return {'loss': loss}

        # The validation engine runs within the training engine's run.
        trainer, validator = ml.train.create_train_and_validation_engines(
            train_batch, train_batch)

        ml.train.add_stdout_handler(trainer, validator)
        ml.train.add_validate_and_checkpoint(
            _dir, model, optimizer, dataset,
            trainer, dataset, validator, save_by_epoch=1)
        ml.train.add_tensorboard_handler(_dir, trainer, every_iteration=True)
        ml.train.add_progress_bar_handler(trainer)

        trainer.run(dataset, max_epochs=3)

        checkpoint_dir = os.path.join(
            trainer.state.output_folder, 'checkpoints')
        assert os.path.exists(trainer.state.output_folder)
        for fname in ('latest.model.pth', 'best.model.pth',
                      'latest.optimizer.pth', 'best.optimizer.pth'):
            assert os.path.exists(os.path.join(checkpoint_dir, fname))
        for i in range(1, 4):
            assert os.path.exists(
                os.path.join(checkpoint_dir, f'epoch{i}.model.pth'))

        assert len(trainer.state.epoch_history['train/loss']) == 3
        assert len(trainer.state.iter_history['loss']) == 10

        # Try resuming from the latest checkpoint.
        model_path = os.path.join(checkpoint_dir, 'latest.model.pth')
        optimizer_path = os.path.join(checkpoint_dir, 'latest.optimizer.pth')

        opt_state_dict = torch.load(
            optimizer_path, map_location=lambda storage, loc: storage)
        state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage)

        optimizer.load_state_dict(opt_state_dict)
        model.load_state_dict(state_dict['state_dict'])

        # The cache must be removed from the saved transforms because it is
        # not a portable transform.
        saved_tfms = state_dict['metadata']['train_dataset']['transforms']
        for t in saved_tfms.transforms:
            assert not isinstance(t, datasets.transforms.Cache)

        new_trainer, new_validator = (
            ml.train.create_train_and_validation_engines(train_batch))

        ml.train.add_stdout_handler(new_trainer)
        ml.train.add_validate_and_checkpoint(
            trainer.state.output_folder, model,
            optimizer, dataset, new_trainer)
        ml.train.add_tensorboard_handler(
            trainer.state.output_folder, new_trainer)

        new_trainer.load_state_dict(
            state_dict['metadata']['trainer.state_dict'])
        assert new_trainer.state.epoch == trainer.state.epoch
        new_trainer.run(dataset, max_epochs=3)
def test_trainer_data_parallel(mix_source_folder):
    """Train and checkpoint a model wrapped in ``torch.nn.DataParallel``.

    The training step is a dummy that reports a random loss. The test only
    checks that the checkpoint files and history lengths come out as expected.
    """
    tfms = datasets.transforms.Compose([
        datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        datasets.transforms.ToSeparationModel()
    ])
    dataset = datasets.MixSourceFolder(mix_source_folder, transform=tfms)

    # Feature count is the second dimension of the magnitude of the first item.
    n_features = dataset[0]['mix_magnitude'].shape[1]
    config = ml.networks.builders.build_recurrent_mask_inference(
        n_features, 50, 2, True, 0.3, 2, 'softmax',
    )
    model = torch.nn.DataParallel(ml.SeparationModel(config))
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    def train_batch(engine, data):
        # Stand-in step: report a random loss instead of running the model.
        return {'loss': np.random.rand()}

    with tempfile.TemporaryDirectory() as tmpdir:
        out_dir = fix_dir if fix_dir else tmpdir
        trainer, validator = ml.train.create_train_and_validation_engines(
            train_batch, train_batch)

        ml.train.add_stdout_handler(trainer, validator)
        ml.train.add_validate_and_checkpoint(
            out_dir, model, optimizer, dataset, trainer, dataset, validator)
        ml.train.add_tensorboard_handler(out_dir, trainer)

        # The validation engine runs within the training engine's run.
        trainer.run(dataset, max_epochs=3)

        output_folder = trainer.state.output_folder
        assert os.path.exists(output_folder)
        checkpoint_dir = os.path.join(output_folder, 'checkpoints')
        for fname in ('latest.model.pth', 'best.model.pth',
                      'latest.optimizer.pth', 'best.optimizer.pth'):
            assert os.path.exists(os.path.join(checkpoint_dir, fname))

        assert len(trainer.state.epoch_history['train/loss']) == 3
        assert len(trainer.state.iter_history['loss']) == 10
def test_gradients(mix_source_folder):
    """Batched and per-item (accumulated) passes must agree for each builder.

    For every network config:

    * one model processes the whole dataset as a single batch;
    * a second, identically seeded model processes the items one at a time,
      dividing each loss by ``len(dataset)`` and accumulating gradients.

    Outputs and mean gradients are then compared. A mismatch means items in a
    minibatch interact, either in the model or in the loss. For configs listed
    in ``config_has_batch_norm`` the output and gradient comparisons are
    skipped. Gradient-flow plots are written to ``tests/local/``.

    NOTE: the seeding and iteration order matter. ``GetExcerpt`` picks
    excerpts, and the re-seed before building the second model is what makes
    the second pass over the dataset see the same items as the first.
    """
    os.makedirs('tests/local/', exist_ok=True)
    utils.seed(0)

    tfms = datasets.transforms.Compose([
        datasets.transforms.GetAudio(),
        datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        datasets.transforms.MagnitudeWeights(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.GetExcerpt(50),
        datasets.transforms.GetExcerpt(
            3136, time_dim=1, tf_keys=['mix_audio', 'source_audio'])
    ])
    dataset = datasets.MixSourceFolder(mix_source_folder, transform=tfms)

    # Create the models based on the first item in the dataset; the second
    # dimension of the magnitude is the number of features.
    n_features = dataset[0]['mix_magnitude'].shape[1]

    names = [
        'dpcl',
        'mask_inference_l1',
        'mask_inference_mse_loss',
        'chimera',
        'open_unmix',
        'end_to_end',
        'dual_path'
    ]
    config_has_batch_norm = ['open_unmix', 'dual_path']

    configs = [
        ml.networks.builders.build_recurrent_dpcl(
            n_features, 50, 1, True, 0.0, 20, ['sigmoid'],
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_recurrent_mask_inference(
            n_features, 50, 1, True, 0.0, 2, ['softmax'],
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_recurrent_mask_inference(
            n_features, 50, 1, True, 0.0, 2, ['softmax'],
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_recurrent_chimera(
            n_features, 50, 1, True, 0.0, 20, ['sigmoid'], 2,
            ['softmax'], normalization_class='InstanceNorm'),
        ml.networks.builders.build_open_unmix_like(
            n_features, 50, 1, True, .4, 2, 1, add_embedding=True,
            embedding_size=20, embedding_activation=['sigmoid', 'unit_norm'],
        ),
        ml.networks.builders.build_recurrent_end_to_end(
            256, 256, 64, 'sqrt_hann', 50, 2, True, 0.0, 2, 'sigmoid',
            num_audio_channels=1, mask_complex=False, rnn_type='lstm',
            mix_key='mix_audio', normalization_class='InstanceNorm'),
        ml.networks.builders.build_dual_path_recurrent_end_to_end(
            64, 16, 8, 60, 30, 50, 2, True, 25, 2, 'sigmoid',
        )
    ]

    loss_dictionaries = [
        {
            'DeepClusteringLoss': {
                'weight': 1.0
            }
        },
        {
            'L1Loss': {
                'weight': 1.0
            }
        },
        {
            'MSELoss': {
                'weight': 1.0
            }
        },
        {
            'DeepClusteringLoss': {
                'weight': 0.2
            },
            'PermutationInvariantLoss': {
                'args': ['L1Loss'],
                'weight': 0.8
            }
        },
        {
            'DeepClusteringLoss': {
                'weight': 0.2
            },
            'PermutationInvariantLoss': {
                'args': ['L1Loss'],
                'weight': 0.8
            }
        },
        {
            'SISDRLoss': {
                'weight': 1.0,
                'keys': {
                    'audio': 'estimates',
                    'source_audio': 'references'
                }
            }
        },
        {
            'SISDRLoss': {
                'weight': 1.0,
                'keys': {
                    'audio': 'estimates',
                    'source_audio': 'references'
                }
            }
        },
    ]

    def append_keys_to_model(name, model):
        # Expose intermediate outputs of the end-to-end models so they are
        # included in the output comparison below.
        if name == 'end_to_end':
            model.output_keys.extend(
                ['audio', 'recurrent_stack', 'mask', 'estimates'])
        elif name == 'dual_path':
            model.output_keys.extend(
                ['audio', 'mixture_weights', 'dual_path', 'mask', 'estimates'])

    def save_gradient_plot(model, path):
        # Plot gradient flow for ``model`` to ``path``, then close the figure
        # so figures don't pile up across configs.
        fig = plt.figure(figsize=(10, 10))
        utils.visualize_gradient_flow(model.named_parameters())
        plt.tight_layout()
        plt.savefig(path)
        plt.close(fig)

    for name, config, loss_dictionary in zip(names, configs, loss_dictionaries):
        loss_closure = ml.train.closures.Closure(loss_dictionary)

        utils.seed(0, set_cudnn=True)
        model_grad = ml.SeparationModel(config, verbose=True).to(DEVICE)
        append_keys_to_model(name, model_grad)

        # Stack every tensor field of every item into one batch.
        all_data = {}
        for data in dataset:
            for key in data:
                if torch.is_tensor(data[key]):
                    data[key] = data[key].float().unsqueeze(0).contiguous().to(
                        DEVICE)
                    if key not in all_data:
                        all_data[key] = data[key]
                    else:
                        all_data[key] = torch.cat(
                            [all_data[key], data[key]], dim=0)

        # Forward and backward pass in batched mode.
        output_grad = model_grad(all_data)
        _loss = loss_closure.compute_loss(output_grad, all_data)
        _loss['loss'].backward()
        save_gradient_plot(model_grad, f'tests/local/{name}:batch_gradient.png')

        # Re-seed so the second model starts from identical weights and the
        # dataset yields the same excerpts as in the batched pass.
        utils.seed(0, set_cudnn=True)
        model_acc = ml.SeparationModel(config).to(DEVICE)
        append_keys_to_model(name, model_acc)

        for i, data in enumerate(dataset):
            for key in data:
                if torch.is_tensor(data[key]):
                    data[key] = data[key].float().unsqueeze(0).contiguous().to(
                        DEVICE)
            # Forward pass on each item individually.
            output_acc = model_acc(data)

            for key in output_acc:
                # The batched and individual forward passes must match; if
                # they don't, items in a minibatch are talking to each other
                # somehow.
                _data_a = output_acc[key]
                _data_b = output_grad[key][i].unsqueeze(0)
                if name not in config_has_batch_norm:
                    assert torch.allclose(_data_a, _data_b, atol=1e-3)

            # Backward pass on each item, scaled so the accumulated gradient
            # is comparable to the batched gradient.
            _loss = loss_closure.compute_loss(output_acc, data)
            _loss['loss'] = _loss['loss'] / len(dataset)
            _loss['loss'].backward()

        save_gradient_plot(
            model_acc, f'tests/local/{name}:accumulated_gradient.png')

        # Batched and accumulated gradients must match; if they don't, the
        # items in a batch are talking to each other in the loss.
        for param1, param2 in zip(model_grad.parameters(),
                                  model_acc.parameters()):
            assert torch.allclose(param1, param2)
            if name in config_has_batch_norm:
                continue
            if not (param1.requires_grad and param2.requires_grad):
                continue
            # A parameter that did not contribute to the loss has no
            # gradient; that must then hold for both models, not just one.
            if param1.grad is None or param2.grad is None:
                assert param1.grad is None and param2.grad is None
                continue
            assert torch.allclose(
                param1.grad.mean(), param2.grad.mean(), atol=1e-3)