def main(conf):
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update the number of sources (it depends on the task).
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPTNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    from asteroid.engine.schedulers import DPTNetScheduler

    schedulers = {
        "scheduler": DPTNetScheduler(
            optimizer, len(train_loader) // conf["training"]["batch_size"], 64
        ),
        "interval": "step",
    }

    # Just after instantiating, save the args for easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define the loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=schedulers,
        train_loader=train_loader,
        val_loader=val_loader,
        config=conf,
    )

    # Define callbacks.
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True)

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        gradient_clip_val=conf["training"]["gradient_clipping"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()
    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
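# A minimal, untested entry-point sketch for main() above. The layout of conf
# (data / training / optim / filterbank / masknet sections plus main_args.exp_dir)
# is inferred from the keys main() reads; "local/conf.yml" and the --exp_dir flag
# are placeholders, not part of the original script.
if __name__ == "__main__":
    import argparse

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--exp_dir", default="exp/tmp", help="Where to store checkpoints and the saved config"
    )
    args = parser.parse_args()

    with open("local/conf.yml") as f:
        conf = yaml.safe_load(f)
    conf["main_args"] = {"exp_dir": args.exp_dir}
    main(conf)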
REDUCE_LR_PATIENCE = 3
EARLY_STOP_PATIENCE = 10
MAX_EPOCHS = 300

# The model should be constructed in the script according to the passed config
# (including the model type). Most of the models accept a `sample_rate` parameter
# for their encoders, which is important (the default is 16000, so override it).
# model = DCUNet("DCUNet-20", fix_length_mode="trim", sample_rate=SAMPLE_RATE)
model = DPTNet(n_src=1)

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint = ModelCheckpoint(
    filename="{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=5,
    verbose=True,
)
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=REDUCE_LR_PATIENCE)
early_stopping = EarlyStopping(monitor="val_loss", patience=EARLY_STOP_PATIENCE)

# We probably also need to subclass `System` to log the target metrics (PESQ/STOI)
# on the validation set (sketched below).
# NOTE: train_loader is reused as the validation loader here.
system = System(model, optimizer, sisdr_loss_wrapper, train_loader, train_loader, scheduler)

# The log dir and model name are also part of the config, of course.
LOG_DIR = "logs"
logger = pl_loggers.TensorBoardLogger(LOG_DIR, name="TIMIT-drones-DPTNet-random", version=1)
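# A rough, untested sketch of the `System` subclass mentioned above: it keeps the
# usual validation loss but also logs STOI on each validation batch. Assumptions
# not taken from the snippet above: batches are (mix, clean) pairs, SAMPLE_RATE is
# defined earlier in the script, the `pystoi` package is installed, and a PyTorch
# Lightning >= 1.x `self.log` API is available. PESQ could be logged the same way
# with the `pesq` package, at a noticeably higher compute cost.
import pytorch_lightning as pl
import torch
from pystoi import stoi

from asteroid.engine.system import System


class MetricsSystem(System):
    def validation_step(self, batch, batch_nb):
        mix, clean = batch
        est = self(mix)
        loss = self.loss_func(est, clean)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)

        # STOI expects 1-D numpy signals, so move to CPU and loop over the batch.
        est_np = est.detach().cpu().numpy()
        clean_np = clean.detach().cpu().numpy()
        stoi_vals = [
            stoi(c.squeeze(), e.squeeze(), SAMPLE_RATE, extended=False)
            for c, e in zip(clean_np, est_np)
        ]
        self.log("val_stoi", sum(stoi_vals) / len(stoi_vals), on_epoch=True)
        return loss


# Possible wiring (also a sketch): swap in MetricsSystem for the plain System above,
# then train with the callbacks/logger already defined. A real validation loader
# should eventually replace the reused train_loader.
system = MetricsSystem(model, optimizer, sisdr_loss_wrapper, train_loader, train_loader, scheduler)
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint, early_stopping],
    logger=logger,
    gpus=-1 if torch.cuda.is_available() else None,
)
trainer.fit(system)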