from typing import Tuple

import pytest
import torch

from asteroid.models import DCUNet
from asteroid_filterbanks import make_enc_dec


def test_dcunet():
    n_fft = 1024
    _, istft = make_enc_dec(
        "stft", n_filters=n_fft, kernel_size=1024, stride=256, sample_rate=16000
    )
    input_samples = istft(torch.zeros((n_fft + 2, 17))).shape[0]
    # `_default_test_model` is a shared helper from this test module (not shown here).
    _default_test_model(DCUNet("DCUNet-10"), input_samples=input_samples)
    _default_test_model(DCUNet("DCUNet-16"), input_samples=input_samples)
    _default_test_model(DCUNet("DCUNet-20"), input_samples=input_samples)
    _default_test_model(DCUNet("Large-DCUNet-20"), input_samples=input_samples)
    _default_test_model(DCUNet("DCUNet-10", n_src=2), input_samples=input_samples)

    # DCUMaskNet should fail with wrong frequency dimensions
    DCUNet("mini").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(torch.zeros((1, 42, 17), dtype=torch.complex64))

    # DCUMaskNet should fail with wrong time dimensions if fix_length_mode is not used
    DCUNet("mini", fix_length_mode="pad").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
    DCUNet("mini", fix_length_mode="trim").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(torch.zeros((1, 9, 16), dtype=torch.complex64))
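# A minimal usage sketch (not part of the test suite) of the DCUNet API the
# test above exercises. "DCUNet-10" and `fix_length_mode` are real arguments;
# the function name and the 1 s input length are ours, for illustration only.
def example_dcunet_usage():
    model = DCUNet("DCUNet-10", fix_length_mode="pad")
    wav = torch.randn(1, 16000)  # (batch, time) raw waveform
    return model(wav)  # separated estimates, (batch, n_src, time)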
def test_dcunet_model(test_shape: Tuple, matching_samples):
    n_samples = 5010
    # `get_default_device` is a shared test utility (not shown here);
    # `assert_consistency` is sketched below.
    device = get_default_device()
    model = DCUNet(architecture="mini", fix_length_mode="pad").eval().to(device)
    # Random input uniformly distributed in [-1, 1]
    inputs = torch.rand(1, n_samples, device=device) * 2 - 1
    traced = torch.jit.trace(model, (inputs,))

    test_data = torch.rand(*test_shape, matching_samples, device=device)
    assert_consistency(model=model, traced=traced, tensor=test_data.to(device))
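# Hedged sketch of the `assert_consistency` helper assumed above: it is taken
# to compare the eager model against its traced version on the same input.
# The tolerance is an illustrative choice, not necessarily the project's.
def assert_consistency(model, traced, tensor):
    with torch.no_grad():
        expected = model(tensor)
        actual = traced(tensor)
    assert torch.allclose(expected, actual, atol=1e-5)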
def test_dcunet_stft_512():
    # Same checks as `test_dcunet`, with a 512-point STFT filterbank.
    _, istft = make_enc_dec("stft", 512, 512)
    input_samples = istft(torch.zeros((514, 17))).shape[0]
    _default_test_model(DCUNet("DCUNet-10"), input_samples=input_samples)
    _default_test_model(DCUNet("DCUNet-10", n_src=2), input_samples=input_samples)

    # DCUMaskNet should fail with wrong frequency dimensions
    DCUNet("mini").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(torch.zeros((1, 42, 17), dtype=torch.complex64))

    # DCUMaskNet should fail with wrong time dimensions if fix_length_mode is not used
    DCUNet("mini", fix_length_mode="pad").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
    DCUNet("mini", fix_length_mode="trim").masker(torch.zeros((1, 9, 17), dtype=torch.complex64))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(torch.zeros((1, 9, 16), dtype=torch.complex64))
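# Why 514 rows: asteroid's STFT filterbank stacks real and imaginary parts, so
# a 512-point STFT yields 2 * (512 / 2 + 1) = 514 output channels. A quick
# check of that assumption (function name and 4096-sample input are ours):
def check_stft_channels():
    stft, _ = make_enc_dec("stft", 512, 512)
    assert stft(torch.zeros(1, 4096)).shape[1] == 514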
import json
import os

import pytorch_lightning as pl
import torch
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from asteroid.data import LibriMix
from asteroid.engine.optimizers import make_optimizer
from asteroid.engine.system import System
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from asteroid.models import DCUNet


def main(conf):
    train_set = LibriMix(
        csv_dir=conf["data"]["train_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["data"]["n_src"],
        segment=conf["data"]["segment"],
    )
    val_set = LibriMix(
        csv_dir=conf["data"]["valid_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["data"]["n_src"],
        segment=conf["data"]["segment"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )

    conf["masknet"].update({"n_src": conf["data"]["n_src"]})

    model = DCUNet(**conf["filterbank"], **conf["masknet"], sample_rate=conf["data"]["sample_rate"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])

    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(EarlyStopping(monitor="val_loss", mode="min", patience=30, verbose=True))

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        limit_train_batches=1.0,  # Useful for fast experiments
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
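# Hedged example of the `conf` dict that `main` consumes, reconstructed from
# the keys accessed above. All values are illustrative; the real recipe loads
# them from a YAML config and the command line, and the exact "filterbank" /
# "masknet" keys are assumptions based on DCUNet's constructor arguments.
example_conf = {
    "data": {
        "train_dir": "data/train",  # illustrative paths
        "valid_dir": "data/dev",
        "task": "sep_clean",
        "sample_rate": 16000,
        "n_src": 2,
        "segment": 3,
    },
    "training": {
        "batch_size": 8,
        "num_workers": 4,
        "half_lr": True,
        "early_stop": True,
        "epochs": 200,
    },
    "filterbank": {"stft_n_filters": 1024, "stft_kernel_size": 1024, "stft_stride": 256},
    "masknet": {"architecture": "Large-DCUNet-20", "fix_length_mode": "pad"},
    "optim": {"lr": 0.001, "weight_decay": 1e-5},
    "main_args": {"exp_dir": "exp/tmp"},
}

if __name__ == "__main__":
    main(example_conf)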