Example #1
def main(conf):
    # from asteroid.data.toy_data import WavSet
    # train_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)
    # val_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)
    # Define data pipeline
    train_set = WhamDataset(conf['data']['train_dir'],
                            conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'],
                          conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'])
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'])
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    # loss_class = PITLossContainer(pairwise_neg_sisdr, n_src=train_set.n_src)
    # The checkpointing callback can monitor any quantity returned by the
    # validation step; it defaults to val_loss here (see System).
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_best_only=False)
    # The new PyTorch Lightning release (due December 7th) will add save_top_k.
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    config=conf)
    trainer = pl.Trainer(max_nb_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='dp')
    trainer.fit(system)
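For reference, the nested `conf` dictionary that `main` receives is built by the recipe from a `conf.yml` file plus command-line arguments. A minimal sketch of the keys this excerpt reads directly is given below; all values are placeholders, and the real configuration also carries `filterbank`, `optim`, and full `masknet` sections consumed by `make_model_and_optimizer`.

# Placeholder sketch of the configuration read by main() above (values assumed).
conf = {
    'data': {'train_dir': 'data/wav8k/min/tr', 'valid_dir': 'data/wav8k/min/cv',
             'task': 'sep_clean', 'sample_rate': 8000, 'nondefault_nsrc': None},
    'training': {'batch_size': 4, 'num_workers': 4, 'epochs': 200},
    'masknet': {},  # 'n_src' is filled in by main()
    'main_args': {'exp_dir': 'exp/tmp', 'gpus': -1},
}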
Example #2
def main(conf):
    exp_dir = conf['main_args']['exp_dir']
    # Define Dataloader
    train_loader, val_loader = make_dataloaders(**conf['data'],
                                                **conf['training'])
    conf['masknet'].update({'n_src': conf['data']['n_src']})
    # Define model, optimizer + scheduler
    model, optimizer = make_model_and_optimizer(conf)
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)

    # Save config
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define loss function
    loss_func = ChimeraLoss(alpha=conf['training']['loss_alpha'])
    # Put together in System
    system = ChimeraSystem(model=model, loss_func=loss_func,
                           optimizer=optimizer, train_loader=train_loader,
                           val_loader=val_loader, scheduler=scheduler,
                           config=conf)

    # Callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)
    gpus = -1
    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        gpus = None

    # Train model
    trainer = pl.Trainer(max_nb_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=gpus,
                         distributed_backend='dp',
                         train_percent_check=1.0,  # Useful for fast experiment
                         gradient_clip_val=200,)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    # Save last model for convenience
    torch.save(system.model.state_dict(),
               os.path.join(exp_dir, 'checkpoints/final.pth'))
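The `best_k_models.json` file written above maps each saved checkpoint path to its validation loss, which makes it easy to locate the best checkpoint afterwards (Example #16 below uses the same pattern). A minimal readback sketch, with `exp/tmp` as a placeholder experiment directory:

import json
import os

import torch

exp_dir = 'exp/tmp'  # placeholder
with open(os.path.join(exp_dir, 'best_k_models.json')) as f:
    best_k = json.load(f)
best_model_path = min(best_k, key=best_k.get)  # lowest val_loss wins
ckpt = torch.load(best_model_path, map_location='cpu')
state_dict = ckpt['state_dict']  # keys may be prefixed by the wrapping System (see Example #16)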
Example #3
def train_model_part(conf, train_part='filterbank', pretrained_filterbank=None):
    train_loader, val_loader = get_data_loaders(conf, train_part=train_part)

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(
        conf, model_part=train_part, pretrained_filterbank=pretrained_filterbank
    )
    # Define scheduler
    scheduler = None
    if conf[train_part + '_training'][train_part[0] + '_half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir, checkpoint_dir = get_encoded_paths(conf, train_part)
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(PairwiseNegSDR('sisdr', zero_mean=False),
                               pit_from='pw_mtx')
    system = SystemTwoStep(model=model, loss_func=loss_func,
                           optimizer=optimizer, train_loader=train_loader,
                           val_loader=val_loader, scheduler=scheduler,
                           config=conf, module=train_part)

    # Define callbacks
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=1, verbose=1)
    early_stopping = False
    if conf[train_part + '_training'][train_part[0] + '_early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)
    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        conf['main_args']['gpus'] = None

    trainer = pl.Trainer(
        max_nb_epochs=conf[train_part + '_training'][train_part[0] + '_epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=conf['main_args']['gpus'],
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(checkpoint_dir, "best_k_models.json"), "w") as file:
        json.dump(best_k, file, indent=0)
Example #4
def get_model(conf):
    # Create the model from recipe-local function
    model, _ = make_model_and_optimizer(conf['train_conf'])
    # Load state dict
    last_model_path = os.path.join(conf['exp_dir'], 'final.pth')
    model.load_state_dict(torch.load(last_model_path, map_location='cpu'))
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model.eval()
    return model
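A usage sketch for `get_model` above, assuming the training configuration was dumped to `conf.yml` in the experiment directory during training (as the training examples do); `exp/tmp` is a placeholder path:

import os

import torch
import yaml

exp_dir = 'exp/tmp'  # placeholder
with open(os.path.join(exp_dir, 'conf.yml')) as f:
    train_conf = yaml.safe_load(f)
conf = {'train_conf': train_conf, 'exp_dir': exp_dir,
        'use_gpu': torch.cuda.is_available()}
model = get_model(conf)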
Example #5
def main(conf):
    train_set = WSJ2mixDataset(conf['data']['tr_wav_len_list'],
                               conf['data']['wav_base_path'] + '/tr',
                               sample_rate=conf['data']['sample_rate'])
    val_set = WSJ2mixDataset(conf['data']['cv_wav_len_list'],
                             conf['data']['wav_base_path'] + '/cv',
                             sample_rate=conf['data']['sample_rate'])
    train_sampler = BucketingSampler(train_set,
                                     batch_size=conf['data']['batch_size'])
    valid_sampler = BucketingSampler(val_set,
                                     batch_size=conf['data']['batch_size'])

    train_loader = DataLoader(train_set,
                              batch_sampler=train_sampler,
                              collate_fn=collate_fn,
                              num_workers=conf['data']['num_workers'])
    val_loader = DataLoader(val_set,
                            batch_sampler=valid_sampler,
                            collate_fn=collate_fn,
                            num_workers=conf['data']['num_workers'])
    model, optimizer = make_model_and_optimizer(conf)
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)
    loss_func = handle_multiple_loss
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min')
    system = DcSystem(model=model,
                      loss_func=loss_func,
                      optimizer=optimizer,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      config=conf)
    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(
        max_nb_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        default_save_path=exp_dir,
        gpus=conf['main_args']['gpus'],
        distributed_backend='dp',
        train_percent_check=1.0  # Useful for fast experiment
    )
    trainer.fit(system)
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
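Example #5 relies on a recipe-local `collate_fn` (and `BucketingSampler`) not shown in this excerpt. Purely as an illustration, a generic padding collate for variable-length (mixture, sources) pairs could look like the sketch below; the recipe's own implementation may return different shapes or extra items.

import torch


def pad_collate(batch):
    # Illustrative only: pad every item in the batch to the longest mixture.
    mixes, sources = zip(*[(torch.as_tensor(m), torch.as_tensor(s)) for m, s in batch])
    max_len = max(m.shape[-1] for m in mixes)

    def pad(x):
        return torch.nn.functional.pad(x, (0, max_len - x.shape[-1]))

    return torch.stack([pad(m) for m in mixes]), torch.stack([pad(s) for s in sources])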
Example #6
def main(conf):
    config = ParamConfig(
        conf["training"]["batch_size"],
        conf["training"]["epochs"],
        conf["training"]["num_workers"],
        cuda=True,
        use_half=False,
        learning_rate=conf["optim"]["lr"],
    )

    dataset = AVSpeechDataset(
        Path("data/train.csv"), Path(EMBED_DIR), conf["main_args"]["n_src"]
    )
    val_dataset = AVSpeechDataset(
        Path("data/val.csv"), Path(EMBED_DIR), conf["main_args"]["n_src"]
    )

    model, optimizer = make_model_and_optimizer(conf)
    print(
        f"AVFusion has {sum(np.prod(i.shape) for i in model.parameters()):,} parameters"
    )

    criterion = DiscriminativeLoss()

    model_path = Path(conf["main_args"]["exp_dir"]) / "checkpoints" / "best_full.pth"
    if model_path.is_file():
        print("Loading saved model...")
        resume = model_path.as_posix()
    else:
        resume = None

    if torch.cuda.device_count() > 1:
        print(f"Multiple GPUs available")
        device_ids = (
            list(map(int, conf["main_args"]["gpus"].split(",")))
            if conf["main_args"]["gpus"] != "-1"
            else None
        )
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    train(
        model,
        dataset,
        optimizer,
        criterion,
        config,
        val_dataset=val_dataset,
        resume=resume,
        logdir=conf["main_args"]["exp_dir"],
    )
Example #7
def main(conf):
    config = ParamConfig(
        conf["training"]["batch_size"],
        conf["training"]["epochs"],
        conf["training"]["num_workers"],
        cuda=True,
        use_half=False,
        learning_rate=conf["optim"]["lr"],
    )

    dataset = AVSpeechDataset(Path("data/train.csv"), Path(EMBED_DIR),
                              conf["main_args"]["n_src"])
    val_dataset = AVSpeechDataset(Path("data/val.csv"), Path(EMBED_DIR),
                                  conf["main_args"]["n_src"])

    model, optimizer = make_model_and_optimizer(conf)
    print(
        f"AVFusion has {sum(np.prod(i.shape) for i in model.parameters()):,} parameters"
    )

    criterion = DiscriminativeLoss()

    model_path = Path(
        conf["main_args"]["exp_dir"]) / "checkpoints" / "best_full.pth"
    if model_path.is_file():
        print("Loading saved model...")
        resume = model_path.as_posix()
    else:
        resume = None

    train(
        model,
        dataset,
        optimizer,
        criterion,
        config,
        val_dataset=val_dataset,
        resume=resume,
        logdir=conf["main_args"]["exp_dir"],
    )
Example #8
def main(conf):
    train_set = WhamRDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamRDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update the number of sources (it depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=True)

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="dp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
Example #9
def main(conf):
    total_set = DNSDataset(conf['data']['json_dir'])
    train_len = int(len(total_set) * (1 - conf['data']['val_prop']))
    val_len = len(total_set) - train_len
    train_set, val_set = random_split(total_set, [train_len, val_len])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = partial(distance, is_complex=conf['main_args']['is_complex'])
    system = SimpleSystem(model=model,
                          loss_func=loss_func,
                          optimizer=optimizer,
                          train_loader=train_loader,
                          val_loader=val_loader,
                          scheduler=scheduler,
                          config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_top_k=5,
                                 verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=30,
                                       verbose=1)

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_nb_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
Example #10
def train_model_part(conf,
                     train_part="filterbank",
                     pretrained_filterbank=None):
    train_loader, val_loader = get_data_loaders(conf, train_part=train_part)

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(
        conf,
        model_part=train_part,
        pretrained_filterbank=pretrained_filterbank)
    # Define scheduler
    scheduler = None
    if conf[train_part + "_training"][train_part[0] + "_half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir, checkpoint_dir = get_encoded_paths(conf, train_part)
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(PairwiseNegSDR("sisdr", zero_mean=False),
                               pit_from="pw_mtx")
    system = SystemTwoStep(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
        module=train_part,
    )

    # Define callbacks
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=1,
                                 verbose=True)
    early_stopping = False
    if conf[train_part + "_training"][train_part[0] + "_early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=True)
    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf[train_part + "_training"][train_part[0] + "_epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="dp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(checkpoint_dir, "best_k_models.json"), "w") as file:
        json.dump(best_k, file, indent=0)
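Examples #3 and #10 train one part of the two-step model at a time. A sketch of how the two calls might be chained, assuming the second part is named "separator" and using a hypothetical helper `load_best_filterbank(conf)` (the real recipe ships its own logic for restoring the trained filterbank):

# First step: train the filterbank alone.
train_model_part(conf, train_part="filterbank")
# Second step: train the separator on top of the pretrained filterbank.
pretrained_fb = load_best_filterbank(conf)  # hypothetical helper, not shown above
train_model_part(conf, train_part="separator",
                 pretrained_filterbank=pretrained_fb)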
Example #11
def main(conf):
    total_set = DNSDataset(conf["data"]["json_dir"])
    train_len = int(len(total_set) * (1 - conf["data"]["val_prop"]))
    val_len = len(total_set) - train_len
    train_set, val_set = random_split(total_set, [train_len, val_len])

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = partial(distance, is_complex=conf["main_args"]["is_complex"])
    system = SimpleSystem(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss",
                          mode="min",
                          patience=30,
                          verbose=True))

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        limit_train_batches=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
Example #12
def main(conf):
    # FIXME: make a function that returns the loaders, taking conf['data'] as input.
    # Where is the mode: min or max?
    train_set = WhamDataset(conf['data']['train_dir'],
                            conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'],
                          conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['data']['batch_size'],
                              num_workers=conf['data']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=True,
                            batch_size=conf['data']['batch_size'],
                            num_workers=conf['data']['num_workers'],
                            drop_last=True)
    # Update the number of sources (it depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.

    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    # The checkpointing callback can monitor any quantity returned by the
    # validation step; it defaults to val_loss here (see System).
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_best_only=False)
    # The new PyTorch Lightning release (due December 7th) will add save_top_k.
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    config=conf)

    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(
        max_nb_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        default_save_path=exp_dir,
        gpus=conf['main_args']['gpus'],
        distributed_backend='dp',
        train_percent_check=1.0  # Useful for fast experiment
    )
    trainer.fit(system)
Example #13
def main(conf):
    best_model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    if not os.path.exists(best_model_path):
        # make pth from checkpoint
        model = load_best_model(conf["train_conf"],
                                conf["exp_dir"],
                                sample_rate=conf["sample_rate"])
        torch.save(model.state_dict(), best_model_path)
    else:
        model, _ = make_model_and_optimizer(conf["train_conf"],
                                            sample_rate=conf["sample_rate"])
        model.eval()
        model.load_state_dict(torch.load(best_model_path))
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_dirs = [
        conf["test_dir"].format(n_src)
        for n_src in conf["train_conf"]["masknet"]["n_srcs"]
    ]
    test_set = Wsj0mixVariable(
        json_dirs=test_dirs,
        n_srcs=conf["train_conf"]["masknet"]["n_srcs"],
        sample_rate=conf["train_conf"]["data"]["sample_rate"],
        seglen=None,
        minlen=None,
    )

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
    series_list = []
    torch.no_grad().__enter__()  # disable autograd for the whole evaluation loop
    for idx in tqdm(range(len(test_set))):
        # Forward the network on the mixture.
        mix, sources = [
            torch.Tensor(x)
            for x in tensors_to_device(test_set[idx], device=model_device)
        ]
        est_sources = model.separate(mix[None])
        p_si_snr = Penalized_PIT_Wrapper(pairwise_neg_sisdr_loss)(est_sources,
                                                                  sources)
        utt_metrics = {
            "P-Si-SNR": p_si_snr.item(),
            "counting_accuracy": float(sources.size(0) == est_sources.size(0)),
        }
        utt_metrics["mix_path"] = test_set.data[idx][0]
        series_list.append(pd.Series(utt_metrics))

        # Save some examples in a folder. Wav files and metrics as text.
        if idx in save_idx:
            mix_np = mix[None].cpu().data.numpy()
            sources_np = sources.cpu().data.numpy()
            est_sources_np = est_sources.cpu().data.numpy()
            local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
            os.makedirs(local_save_dir, exist_ok=True)
            sf.write(local_save_dir + "mixture.wav", mix_np[0],
                     conf["sample_rate"])
            # Loop over the sources and estimates
            for src_idx, src in enumerate(sources_np):
                sf.write(local_save_dir + "s{}.wav".format(src_idx + 1), src,
                         conf["sample_rate"])
            for src_idx, est_src in enumerate(est_sources_np):
                sf.write(
                    local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
                    est_src,
                    conf["sample_rate"],
                )
            # Write local metrics to the example folder.
            with open(local_save_dir + "metrics.json", "w") as f:
                json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))

    # Print and save summary metrics
    final_results = {}
    for metric_name in ["P-Si-SNR", "counting_accuracy"]:
        final_results[metric_name] = all_metrics_df[metric_name].mean()
    print("Overall metrics :")
    pprint(final_results)
    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)
Example #14
def main(conf):
    # from asteroid.data.toy_data import WavSet
    # train_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)
    # val_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)
    # Define data pipeline
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
    )
    conf["masknet"].update({"n_src": train_set.n_src})

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    # loss_class = PITLossContainer(pairwise_neg_sisdr, n_src=train_set.n_src)
    # The checkpointing callback can monitor any quantity returned by the
    # validation step; it defaults to val_loss here (see System).
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_best_only=False
    )
    # The new PyTorch Lightning release (due December 7th) will add save_top_k.
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        config=conf,
    )
    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_nb_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend="dp",
    )
    trainer.fit(system)
Example #15
def main(conf):
    exp_dir = conf["main_args"]["exp_dir"]
    # Define Dataloader
    train_loader, val_loader = make_dataloaders(**conf["data"], **conf["training"])
    conf["masknet"].update({"n_src": conf["data"]["n_src"]})
    # Define model, optimizer + scheduler
    model, optimizer = make_model_and_optimizer(conf)
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Save config
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define loss function
    loss_func = ChimeraLoss(alpha=conf["training"]["loss_alpha"])
    # Put together in System
    system = ChimeraSystem(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True)
    gpus = -1
    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print("No available GPU was found, setting gpus to None")
        gpus = None

    # Train model
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="dp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=200,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    # Save last model for convenience
    torch.save(system.model.state_dict(), os.path.join(exp_dir, "checkpoints/final.pth"))
Example #16
def main(conf):
    # Make the model
    model, _ = make_model_and_optimizer(conf['train_conf'])
    # Load best model
    with open(os.path.join(conf['exp_dir'], 'best_k_models.json'), "r") as f:
        best_k = json.load(f)
    best_model_path = min(best_k, key=best_k.get)
    # Load checkpoint
    checkpoint = torch.load(best_model_path, map_location='cpu')
    state = checkpoint['state_dict']
    state_copy = state.copy()
    # Remove unwanted keys
    for keys, values in state.items():
        if keys.startswith('loss'):
            del state_copy[keys]
            print(keys)
    model = torch_utils.load_state_dict_in(state_copy, model)
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = LibriMix(csv_dir=conf['test_dir'],
                        task=conf['task'],
                        sample_rate=conf['sample_rate'],
                        n_src=conf['train_conf']['data']['n_src'],
                        segment=None)  # Use the full signal length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')

    # Randomly choose the indexes of sentences to save.
    eval_save_dir = os.path.join(conf['exp_dir'], conf['out_dir'])
    ex_save_dir = os.path.join(eval_save_dir, 'examples/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
    series_list = []
    torch.no_grad().__enter__()  # disable autograd for the whole evaluation loop
    for idx in tqdm(range(len(test_set))):
        # Forward the network on the mixture.
        mix, sources = tensors_to_device(test_set[idx], device=model_device)
        est_sources = model(mix.unsqueeze(0))
        loss, reordered_sources = loss_func(est_sources,
                                            sources[None],
                                            return_est=True)
        mix_np = mix.cpu().data.numpy()
        sources_np = sources.squeeze().cpu().data.numpy()
        est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
        # For each utterance, we get a dictionary with the mixture path,
        # the input and output metrics
        utt_metrics = get_metrics(mix_np,
                                  sources_np,
                                  est_sources_np,
                                  sample_rate=conf['sample_rate'])
        utt_metrics['mix_path'] = test_set.mixture_path
        series_list.append(pd.Series(utt_metrics))

        # Save some examples in a folder. Wav files and metrics as text.
        if idx in save_idx:
            local_save_dir = os.path.join(ex_save_dir, 'ex_{}/'.format(idx))
            os.makedirs(local_save_dir, exist_ok=True)
            sf.write(local_save_dir + "mixture.wav", mix_np,
                     conf['sample_rate'])
            # Loop over the sources and estimates
            for src_idx, src in enumerate(sources_np):
                sf.write(local_save_dir + "s{}.wav".format(src_idx), src,
                         conf['sample_rate'])
            for src_idx, est_src in enumerate(est_sources_np):
                sf.write(local_save_dir + "s{}_estimate.wav".format(src_idx),
                         est_src, conf['sample_rate'])
            # Write local metrics to the example folder.
            with open(local_save_dir + 'metrics.json', 'w') as f:
                json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, 'all_metrics.csv'))

    # Print and save summary metrics
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = 'input_' + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + '_imp'] = ldf.mean()
    print('Overall metrics :')
    pprint(final_results)
    with open(os.path.join(eval_save_dir, 'final_metrics.json'), 'w') as f:
        json.dump(final_results, f, indent=0)
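Note that `compute_metrics` in Example #16 is a module-level constant of the recipe, outside this excerpt; it is simply the list of metric names reported by `get_metrics`. A plausible (assumed) definition:

# Assumed list of metric names; the actual recipe may include more (e.g. 'pesq').
compute_metrics = ['si_sdr', 'sdr', 'sir', 'sar', 'stoi']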
Example #17
def main(conf):
    if conf["data"]["data_augmentation"]:
        from local.augmented_wham import AugmentedWhamDataset
        train_set = AugmentedWhamDataset(
            task=conf['data']['task'],
            segment=conf['data']['segment'],
            json_dir=conf["data"]["train_dir"],
            sample_rate=conf['data']['sample_rate'],
            nondefault_nsrc=conf['data']['nondefault_nsrc'],
            **conf["augmentation"])
    else:
        train_set = WhamDataset(
            conf['data']['train_dir'],
            conf['data']['task'],
            sample_rate=conf['data']['sample_rate'],
            segment=conf['data']['segment'],
            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'],
                          conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update number of source values (It depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_top_k=5,
                                 verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=30,
                                       verbose=1)

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_nb_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend='dp',
        gradient_clip_val=conf['training']["gradient_clipping"])
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
Example #18
def main(conf):
    train_set = LibriMix(csv_dir=conf['data']['train_dir'],
                         task=conf['data']['task'],
                         sample_rate=conf['data']['sample_rate'],
                         n_src=conf['data']['n_src'],
                         segment=conf['data']['segment'])

    val_set = LibriMix(csv_dir=conf['data']['valid_dir'],
                       task=conf['data']['task'],
                       sample_rate=conf['data']['sample_rate'],
                       n_src=conf['data']['n_src'],
                       segment=conf['data']['segment'])

    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)

    val_loader = DataLoader(val_set, shuffle=True,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    conf['masknet'].update({'n_src': conf['data']['n_src']})

    # Define model and optimizer in a local function (defined in the recipe).
    # This makes re-instantiating the model and optimizer for retraining
    # and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)

    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(max_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='dp',
                         train_percent_check=1.0,  # Useful for fast experiment
                         gradient_clip_val=5.)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
Example #19
def main(conf):
    train_dirs = [
        conf["data"]["train_dir"].format(n_src)
        for n_src in conf["masknet"]["n_srcs"]
    ]
    valid_dirs = [
        conf["data"]["valid_dir"].format(n_src)
        for n_src in conf["masknet"]["n_srcs"]
    ]
    train_set = Wsj0mixVariable(
        json_dirs=train_dirs,
        n_srcs=conf["masknet"]["n_srcs"],
        sample_rate=conf["data"]["sample_rate"],
        seglen=conf["data"]["seglen"],
        minlen=conf["data"]["minlen"],
    )
    val_set = Wsj0mixVariable(
        json_dirs=valid_dirs,
        n_srcs=conf["masknet"]["n_srcs"],
        sample_rate=conf["data"]["sample_rate"],
        seglen=conf["data"]["seglen"],
        minlen=conf["data"]["minlen"],
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
        collate_fn=_collate_fn,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
        collate_fn=_collate_fn,
    )
    model, optimizer = make_model_and_optimizer(
        conf, sample_rate=conf["data"]["sample_rate"])
    scheduler = []
    if conf["training"]["half_lr"]:
        scheduler.append(
            ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5))
    if conf["training"]["lr_decay"]:
        scheduler.append(ExponentialLR(optimizer=optimizer, gamma=0.99))
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)
    loss_func = WeightedPITLoss(n_srcs=conf["masknet"]["n_srcs"],
                                lamb=conf["loss"]["lambda"])
    # Put together in System
    system = VarSpkrSystem(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch}-{step}",
        monitor="avg_sdr",
        mode="max",
        save_top_k=5,
        verbose=True,
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="avg_sdr",
                          mode="max",
                          patience=30,
                          verbose=True))

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "dp" if torch.cuda.is_available() else None

    # Train model
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        limit_train_batches=1.0,  # Useful for fast experiment
        gradient_clip_val=200,
        resume_from_checkpoint=conf["main_args"]["resume_from"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    # Save last model for convenience
    torch.save(system.model.state_dict(),
               os.path.join(exp_dir, "final_model.pth"))