def main(conf):
    """Load the trained DPRNN-TasNet and, if no single file is configured,
    build the LibriMix test set.

    Args:
        conf (dict): evaluation configuration; reads ``exp_dir``, ``use_gpu``,
            ``file_path``, ``test_dir``, ``task``, ``sample_rate`` and
            ``train_conf["masknet"]["n_src"]``.

    NOTE(review): the visible body stops right after constructing the test
    set; the model, its device and the dataset are never used afterwards.
    This looks like a truncated evaluation script -- TODO confirm.
    """
    checkpoint_file = os.path.join(conf["exp_dir"], "best_model.pth")
    net = DPRNNTasNet.from_pretrained(checkpoint_file)
    # Handle device placement: move the weights to CUDA only when asked to.
    if conf["use_gpu"]:
        net.cuda()
    net_device = next(net.parameters()).device
    # The test set is only built when ``file_path`` is the empty string.
    if conf['file_path'] != '':
        return None
    test_set = LibriMix(
        csv_dir=conf["test_dir"],
        task=conf["task"],
        sample_rate=conf["sample_rate"],
        n_src=conf["train_conf"]["masknet"]["n_src"],
        segment=None,  # Uses all segment length
    )
def main(conf):
    """Train a DPRNN-TasNet on WHAM! and export the best checkpoint.

    Args:
        conf (dict): nested configuration with ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args`` sections.
            ``conf["masknet"]["n_src"]`` is overwritten in place with the
            number of sources of the training set before the model is built.

    Side effects:
        Writes ``conf.yml``, ``checkpoints/``, ``best_k_models.json`` and
        ``best_model.pth`` under ``conf["main_args"]["exp_dir"]``.
    """
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPRNNTasNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Optionally halve the learning rate (factor=0.5) when the loss plateaus.
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True)

    # Don't ask for GPUs, nor the DDP backend, if no GPU is available.
    use_cuda = torch.cuda.is_available()
    gpus = -1 if use_cuda else None
    distributed_backend = "ddp" if use_cuda else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        gradient_clip_val=conf["training"]["gradient_clipping"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Load onto the CPU so the export also works on a machine without the
    # GPU the checkpoint was written from; the system is moved to CPU anyway.
    state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
def main(conf):
    """Evaluate a trained DPRNN-TasNet on the WHAM! test set.

    For every test utterance the mixture is separated and the estimates are
    reordered to the references with PIT. Per-utterance metrics from
    ``get_metrics`` are written to ``all_metrics.csv``. A random subset of
    examples is saved as wav files plus ``metrics.json``. Means and
    improvements over the input mixture go to ``final_metrics.json``, and a
    publishable copy of the model is written to ``<exp_dir>/publish_dir``.

    Relies on the module-level names ``compute_metrics`` (metric names to
    summarise) and ``train_conf`` (training configuration).

    Args:
        conf (dict): evaluation configuration; reads ``exp_dir``, ``use_gpu``,
            ``test_dir``, ``task``, ``sample_rate`` and ``n_save_ex``.
            ``conf["n_save_ex"] == -1`` means "save every example" and is
            replaced in place by the test-set length.
    """
    model_path = os.path.join(conf['exp_dir'], 'best_model.pth')
    model = DPRNNTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = WhamDataset(conf['test_dir'], conf['task'],
                           sample_rate=conf['sample_rate'],
                           nondefault_nsrc=model.masker.n_src,
                           segment=None)  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    # Clamp to the dataset size (random.sample raises otherwise) and keep the
    # indexes in a set for O(1) membership tests inside the loop.
    n_save = min(conf['n_save_ex'], len(test_set))
    save_idx = set(random.sample(range(len(test_set)), n_save))
    series_list = []
    # Disable autograd only for the inference loop, not for the whole process.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture, adding two leading singleton
            # axes (presumably batch and channel -- confirm model input layout).
            mix, sources = tensors_to_device(test_set[idx], device=model_device)
            est_sources = model(mix[None, None])
            _, reordered_sources = loss_func(est_sources, sources[None],
                                             return_est=True)
            mix_np = mix[None].cpu().data.numpy()
            sources_np = sources.squeeze().cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
            utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
                                      sample_rate=conf['sample_rate'])
            utt_metrics['mix_path'] = test_set.mix[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, 'ex_{}/'.format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np[0],
                         conf['sample_rate'])
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(local_save_dir + "s{}.wav".format(src_idx + 1),
                             src, conf['sample_rate'])
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
                             est_src, conf['sample_rate'])
                # Write local metrics to the example folder.
                with open(local_save_dir + 'metrics.json', 'w') as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf['exp_dir'], 'all_metrics.csv'))

    # Print and save summary metrics: the mean of each metric and its mean
    # improvement over the corresponding 'input_' column.
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = 'input_' + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + '_imp'] = ldf.mean()
    print('Overall metrics :')
    pprint(final_results)
    with open(os.path.join(conf['exp_dir'], 'final_metrics.json'), 'w') as f:
        json.dump(final_results, f, indent=0)

    model_dict = torch.load(model_path, map_location='cpu')
    publish_dir = os.path.join(conf['exp_dir'], 'publish_dir')
    os.makedirs(publish_dir, exist_ok=True)
    save_publishable(publish_dir, model_dict, metrics=final_results,
                     train_conf=train_conf)
def main(conf):
    """Train a DPRNN-TasNet on WHAM! (older PyTorch Lightning API) and export
    the best checkpoint to ``<exp_dir>/best_model.pth``.

    Args:
        conf (dict): nested configuration with ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args`` sections.
            ``conf['masknet']['n_src']`` is overwritten in place with the
            number of sources of the training set before the model is built.

    Side effects:
        Writes ``conf.yml``, ``checkpoints/``, ``best_k_models.json`` and
        ``best_model.pth`` under ``conf['main_args']['exp_dir']``.
    """
    train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            segment=conf['data']['segment'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set, shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update number of source values (It depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    model = DPRNNTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Optionally halve the learning rate (factor=0.5) when the loss plateaus.
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=30,
                                       verbose=1)

    # Don't ask for GPUs, nor the DDP backend, if no GPU is available.
    use_cuda = torch.cuda.is_available()
    gpus = -1 if use_cuda else None
    distributed_backend = 'ddp' if use_cuda else None
    trainer = pl.Trainer(max_nb_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=gpus,
                         distributed_backend=distributed_backend,
                         gradient_clip_val=conf['training']["gradient_clipping"])
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model (next PL version will make this easier).
    # min() with a key returns the first path holding the lowest monitored
    # loss, in a single pass over best_k.
    best_path = min(best_k, key=best_k.get)
    # Load onto the CPU: the system is moved to CPU below anyway.
    state_dict = torch.load(best_path, map_location='cpu')
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))
# Command-line helper: separate audio files with the pretrained
# 'mpariente/DPRNNTasNet_WHAM!_sepclean' DPRNN-TasNet.
# Usage: python <script> FILE [FILE ...]   (FILEs are looked up in ./audio-data/)
import os
import sys

from asteroid import DPRNNTasNet

# NOTE(review): the pretrained model is loaded as soon as this module is
# imported, not lazily on the first conversion.
model = DPRNNTasNet.from_pretrained('mpariente/DPRNNTasNet_WHAM!_sepclean')


def main(args):
    """Separate every file named in ``args`` (typically ``sys.argv[1:]``).

    Exits with a usage message when ``args`` is empty, instead of failing
    with an ``IndexError``.
    """
    if not args:
        sys.exit("usage: {} FILE [FILE ...]".format(os.path.basename(sys.argv[0])))
    for filename in args:
        convert(filename)


def convert(filename, audio_dir='./audio-data/'):
    """Run ``model.separate`` on ``filename`` located under ``audio_dir``.

    Where the separated outputs are written is decided by asteroid's
    ``separate``, not by this function.
    """
    model.separate(os.path.join(audio_dir, filename))


if __name__ == "__main__":
    main(sys.argv[1:])
def main(conf):
    """Evaluate a trained DPRNN-TasNet on LibriMix, optionally reporting WER.

    For every test utterance the mixture is separated and the estimates are
    reordered to the references with PIT. Per-utterance metrics are collected
    from ``get_metrics`` plus whatever ``wer_tracker`` returns. A random subset
    of examples is written as wav files plus ``metrics.json``. Outputs under
    ``<exp_dir>/<out_dir>``:

    - per-utterance metrics in ``all_metrics.csv``;
    - means and improvements over the input mixture in ``final_metrics.json``;
    - the markdown WER report in ``final_wer.md``, when WER is enabled.

    Finally a publishable copy of the model is written to
    ``<exp_dir>/publish_dir``.

    Relies on module-level ``COMPUTE_METRICS``, ``ASR_MODEL_PATH`` and
    ``train_conf``.

    Args:
        conf (dict): evaluation configuration; reads ``exp_dir``, ``out_dir``,
            ``use_gpu``, ``test_dir``, ``task``, ``sample_rate``,
            ``n_save_ex``, ``compute_wer`` and ``train_conf``.
            ``conf["n_save_ex"] == -1`` means "save every example" and is
            replaced in place by the test-set length.
    """
    # NOTE(review): presumably compute_metrics may additionally contain WER,
    # which get_metrics below does not compute (it gets COMPUTE_METRICS) --
    # WER values come from wer_tracker instead; confirm.
    compute_metrics = update_compute_metrics(conf["compute_wer"], COMPUTE_METRICS)
    # The annotation CSV is only consumed by WERTracker, so read it only when
    # WER is requested; plain evaluation then works without the file.
    if conf["compute_wer"]:
        anno_df = pd.read_csv(
            Path(conf["test_dir"]).parent.parent.parent / "test_annotations.csv"
        )
        wer_tracker = WERTracker(ASR_MODEL_PATH, anno_df)
    else:
        wer_tracker = MockWERTracker()

    model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    model = DPRNNTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = LibriMix(
        csv_dir=conf["test_dir"],
        task=conf["task"],
        sample_rate=conf["sample_rate"],
        n_src=conf["train_conf"]["data"]["n_src"],
        segment=None,
        return_id=True,
    )  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    # Randomly choose the indexes of sentences to save.
    eval_save_dir = os.path.join(conf["exp_dir"], conf["out_dir"])
    ex_save_dir = os.path.join(eval_save_dir, "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    # Clamp to the dataset size (random.sample raises otherwise) and keep the
    # indexes in a set for O(1) membership tests inside the loop.
    n_save = min(conf["n_save_ex"], len(test_set))
    save_idx = set(random.sample(range(len(test_set)), n_save))
    series_list = []
    # Disable autograd only for the inference loop, not for the whole process.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources, ids = test_set[idx]
            mix, sources = tensors_to_device([mix, sources], device=model_device)
            est_sources = model(mix.unsqueeze(0))
            _, reordered_sources = loss_func(est_sources, sources[None], return_est=True)
            mix_np = mix.cpu().data.numpy()
            sources_np = sources.cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
            # For each utterance, we get a dictionary with the mixture path,
            # the input and output metrics
            utt_metrics = get_metrics(
                mix_np,
                sources_np,
                est_sources_np,
                sample_rate=conf["sample_rate"],
                metrics_list=COMPUTE_METRICS,
            )
            utt_metrics["mix_path"] = test_set.mixture_path
            est_sources_np_normalized = normalize_estimates(est_sources_np, mix_np)
            utt_metrics.update(
                **wer_tracker(
                    mix=mix_np,
                    clean=sources_np,
                    estimate=est_sources_np_normalized,
                    wav_id=ids,
                    sample_rate=conf["sample_rate"],
                )
            )
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np, conf["sample_rate"])
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(
                        local_save_dir + "s{}.wav".format(src_idx),
                        src,
                        conf["sample_rate"],
                    )
                for src_idx, est_src in enumerate(est_sources_np_normalized):
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv"))

    # Print and save summary metrics: the mean of each metric and its mean
    # improvement over the corresponding 'input_' column.
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()
    print("Overall metrics :")
    pprint(final_results)
    if conf["compute_wer"]:
        print("\nWER report")
        wer_card = wer_tracker.final_report_as_markdown()
        print(wer_card)
        # Save the report
        with open(os.path.join(eval_save_dir, "final_wer.md"), "w") as f:
            f.write(wer_card)

    with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)

    model_dict = torch.load(model_path, map_location="cpu")
    publish_dir = os.path.join(conf["exp_dir"], "publish_dir")
    os.makedirs(publish_dir, exist_ok=True)
    save_publishable(
        publish_dir,
        model_dict,
        metrics=final_results,
        train_conf=train_conf,
    )
def main(conf):
    """Evaluate a trained DPRNN-TasNet on a ``SourceFolderDataset`` test split.

    For every test utterance the mixture is separated and the estimates are
    reordered to the references with PIT. Per-utterance metrics from
    ``get_metrics`` are written to ``all_metrics.csv``. A random subset of
    examples is saved as wav files plus ``metrics.json``; saved estimates are
    rescaled to the mixture's peak amplitude. Means and improvements over the
    input mixture go to ``final_metrics.json``, and a publishable copy of the
    model is written to ``<exp_dir>/publish_dir``.

    Relies on module-level ``compute_metrics`` and ``train_conf``.

    Args:
        conf (dict): evaluation configuration; reads ``exp_dir``, ``use_gpu``,
            ``wav_dir``, ``n_src``, ``sample_rate``, ``batch_size`` and
            ``n_save_ex``. ``conf["n_save_ex"] == -1`` means "save every
            example" and is replaced in place by the test-set length.
    """
    model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    model = DPRNNTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = SourceFolderDataset(
        os.path.join(conf["exp_dir"], "json/"),
        conf["wav_dir"],
        conf["n_src"],
        conf["sample_rate"],
        conf["batch_size"],
        train=False,
    )  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    # Clamp to the dataset size (random.sample raises otherwise) and keep the
    # indexes in a set for O(1) membership tests inside the loop.
    n_save = min(conf["n_save_ex"], len(test_set))
    save_idx = set(random.sample(range(len(test_set)), n_save))
    series_list = []
    # Disable autograd only for the inference loop, not for the whole process.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture, with a leading axis of size 1.
            mix, sources = tensors_to_device(test_set[idx], device=model_device)
            mix = mix.unsqueeze(0)
            sources = sources.unsqueeze(0)
            est_sources = model(mix)
            _, reordered_sources = loss_func(est_sources, sources, return_est=True)
            # mix_np keeps the added leading axis; sources/estimates drop it.
            mix_np = mix.cpu().data.numpy()
            sources_np = sources.squeeze(0).cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
            utt_metrics = get_metrics(
                mix_np,
                sources_np,
                est_sources_np,
                sample_rate=conf["sample_rate"],
                metrics_list=compute_metrics,
            )
            utt_metrics["mix_path"] = test_set.mix[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                # Swap the first two axes before writing -- presumably
                # (1, time) -> (time, 1) for soundfile; TODO confirm layout.
                sf.write(
                    local_save_dir + "mixture.wav",
                    np.swapaxes(mix_np, 0, 1),
                    conf["sample_rate"],
                )
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(
                        local_save_dir + "s{}.wav".format(src_idx + 1),
                        src,
                        conf["sample_rate"],
                    )
                mix_peak = np.max(np.abs(mix_np))
                for src_idx, est_src in enumerate(est_sources_np):
                    # Rescale each estimate to the mixture's peak amplitude.
                    # A silent estimate is written unscaled: dividing by its
                    # zero peak would produce a NaN-filled wav.
                    est_peak = np.max(np.abs(est_src))
                    if est_peak > 0:
                        est_src = est_src * (mix_peak / est_peak)
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))

    # Print and save summary metrics: the mean of each metric and its mean
    # improvement over the corresponding 'input_' column.
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()
    print("Overall metrics :")
    pprint(final_results)
    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)

    model_dict = torch.load(model_path, map_location="cpu")
    publish_dir = os.path.join(conf["exp_dir"], "publish_dir")
    os.makedirs(publish_dir, exist_ok=True)
    save_publishable(
        publish_dir,
        model_dict,
        metrics=final_results,
        train_conf=train_conf,
    )
# NOTE(review): this line is a fragment whose original line breaks appear to
# have been lost; as written, the whole line is a single comment because it
# starts with '#'. TODO restore the formatting. It contains, in order:
#   1. the tail of a function whose start is not visible here: it sorts
#      ``total_df`` in place by 'SI-SDR', 'PESQ' and 'STOI', rounds those
#      columns to 3 decimals, prints the frame and returns it;
#   2. a ``models`` dict mapping short names to models loaded with
#      ``from_pretrained`` ('input' maps to None). Every model is loaded
#      eagerly when the dict is built;
#   3. the start of ``eval_all_and_plot(models, test_set, directory,
#      plot_name)``, which for each model prints a label from ``model_labels``
#      and reads ``<...>/{directory}/{model_name}.csv`` when that file exists.
#      Its body is cut off here.
# NOTE(review): several checkpoint and CSV paths are absolute paths under
# /jmain01/..., so this only runs on that machine -- consider making them
# configurable.
# NOTE(review): 'auto_encoder' is loaded with VAE.from_pretrained -- confirm
# this is intended.
#import pdb; pdb.set_trace() total_df.sort_values(['SI-SDR', 'PESQ', 'STOI'], inplace=True) total_df = total_df.round({'SI-SDR': 3, 'PESQ': 3, 'STOI': 3}) print(total_df) return total_df models = { 'input': None, 'baseline': RegressionFCNN.from_pretrained('models/baseline_model_v1.pt'), 'vae': VAE.from_pretrained('/jmain01/home/JAD007/txk02/aaa18-txk02/workspace/models/VAE.pt'), 'auto_encoder': VAE.from_pretrained('/jmain01/home/JAD007/txk02/aaa18-txk02/workspace/models/AutoEncoder.pt'), 'waveunet_v1': WaveUNet.from_pretrained('models/waveunet_model_adapt.pt'), 'dcunet_20': DCUNet.from_pretrained('models/dcunet_20_random_v2.pt'), 'dccrn': DCCRNet.from_pretrained('models/dccrn_random_v1.pt'), 'smolnet': SMoLnet.from_pretrained('models/SMoLnet.pt'), 'dprnn': DPRNNTasNet.from_pretrained('models/dprnn_model.pt'), 'conv_tasnet': ConvTasNet.from_pretrained('models/convtasnet_model.pt'), 'dptnet': DPTNet.from_pretrained('models/dptnet_model.pt'), 'demucs': Demucs.from_pretrained('models/Demucs.pt'), } def eval_all_and_plot(models, test_set, directory, plot_name): results_dfs = {} for model_name, model in models.items(): print(f'Evaluating {model_labels[model_name]}') csv_path = f'/jmain01/home/JAD007/txk02/aaa18-txk02/DRONE_project/asteroid/notebooks/{directory}/{model_name}.csv' if os.path.isfile(csv_path): print('Results already available') df = pd.read_csv(csv_path)
def _latest_checkpoint(checkpoint_dir):
    """Return the ``*.ckpt`` file in ``checkpoint_dir`` with the highest epoch.

    The epoch is parsed from an ``epoch=<n>`` token in the file name. Files
    without one rank lowest, and ties are broken by modification time.
    Plain lexicographic sorting is wrong here because ``"epoch=9"`` sorts
    after ``"epoch=10"``.

    Returns:
        str or None: the checkpoint path, or ``None`` if there is none.
    """
    import re
    from glob import glob

    def _sort_key(path):
        match = re.search(r"epoch=(\d+)", os.path.basename(path))
        epoch = int(match.group(1)) if match else -1
        return epoch, os.path.getmtime(path)

    ckpts = glob(os.path.join(checkpoint_dir, "*.ckpt"))
    if not ckpts:
        return None
    return max(ckpts, key=_sort_key)


def main(conf):
    """Train a DPRNN-TasNet on LibriMix and export the best checkpoint.

    Args:
        conf (dict): nested configuration with ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args`` sections.
            When ``conf["training"]["cont"]`` is true, training resumes from
            the highest-epoch checkpoint in ``<exp_dir>/checkpoints/``.

    Raises:
        FileNotFoundError: if resuming was requested but no checkpoint exists.

    Side effects:
        Writes ``conf.yml``, ``checkpoints/``, ``best_k_models.json`` and
        ``best_model.pth`` under ``conf["main_args"]["exp_dir"]``.
    """
    train_set = LibriMix(
        csv_dir=conf["data"]["train_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["masknet"]["n_src"],
        segment=conf["data"]["segment"],
    )
    val_set = LibriMix(
        csv_dir=conf["data"]["valid_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["masknet"]["n_src"],
        segment=conf["data"]["segment"],
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task).
    # TODO: redundant -- both datasets were built with conf["masknet"]["n_src"].
    conf["masknet"].update({"n_src": train_set.n_src})
    model = DPRNNTasNet(
        **conf["filterbank"], **conf["masknet"], sample_rate=conf["data"]["sample_rate"]
    )
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Optionally halve the learning rate (factor=0.5) when the loss plateaus.
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss", mode="min", patience=30, verbose=True)
        )

    # Don't ask GPU if they are not available.
    use_cuda = torch.cuda.is_available()
    gpus = -1 if use_cuda else None
    distributed_backend = "ddp" if use_cuda else None

    # None means "start from scratch", the same as omitting the argument.
    resume_ckpt = None
    if conf["training"]["cont"]:
        resume_ckpt = _latest_checkpoint(checkpoint_dir)
        if resume_ckpt is None:
            raise FileNotFoundError(
                "Cannot continue training: no *.ckpt file in {}".format(checkpoint_dir)
            )
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        limit_train_batches=1.0,  # Useful for fast experiment
        gradient_clip_val=conf["training"]["gradient_clipping"],
        resume_from_checkpoint=resume_ckpt,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Load onto the CPU: the system is moved to CPU below anyway.
    state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
def main(conf):
    """Load the exported DPRNN-TasNet from ``<conf["exp_dir"]>/best_model.pth``.

    NOTE(review): the visible body ends right after loading; the model is
    neither used nor returned, so this looks truncated -- TODO confirm.
    """
    pretrained = DPRNNTasNet.from_pretrained(
        os.path.join(conf["exp_dir"], "best_model.pth")
    )
# NOTE(review): fragment -- this line begins in the middle of a call whose
# start is not visible here, and its original line breaks appear to have been
# lost. As written, the first '#' on the line turns everything after it into
# a comment. TODO restore the formatting. In order, the code:
#   - builds ``val_loader`` from ``timit_val``;
#   - sets LR=1e-3, REDUCE_LR_PATIENCE=5, EARLY_STOP_PATIENCE=20 and
#     MAX_EPOCHS=300;
#   - builds ``DPRNNTasNet(n_src=1)``;
#   - creates a ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=5);
#   - creates an Adam optimizer, a ReduceLROnPlateau scheduler and an
#     EarlyStopping callback on 'val_loss'.
# NOTE(review): ``DPRNNTasNet(n_src=1)`` does not pass ``sample_rate``, even
# though the inline comment says it should be overridden -- confirm the
# default matches the TIMIT data.
drop_last=True) val_loader = DataLoader(timit_val, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, drop_last=True) # some random parameters, does it look sensible? LR = 1e-3 REDUCE_LR_PATIENCE = 5 EARLY_STOP_PATIENCE = 20 MAX_EPOCHS = 300 # the model here should be constructed in the script accordingly to the passed config (including the model type) # most of the models accept `sample_rate` parameter for encoders, which is important (default is 16000, override) #model = DCUNet("DCUNet-20", fix_length_mode="trim", sample_rate=SAMPLE_RATE) model = DPRNNTasNet(n_src=1) from pytorch_lightning.callbacks import ModelCheckpoint checkpoint = ModelCheckpoint(filename='{epoch:02d}-{val_loss:.2f}', monitor="val_loss", mode="min", save_top_k=5, verbose=True) optimizer = optim.Adam(model.parameters(), lr=LR) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=REDUCE_LR_PATIENCE) early_stopping = EarlyStopping(monitor='val_loss', patience=EARLY_STOP_PATIENCE) # Probably we also need to subclass `System`, in order to log the target metrics on the validation set (PESQ/STOI)