Example #1
def main(conf):
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPTNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    from asteroid.engine.schedulers import DPTNetScheduler

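    # DPTNet uses a warm-up learning-rate schedule; "interval": "step" makes
    # Lightning call the scheduler after every optimizer step.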
    schedulers = {
        "scheduler": DPTNetScheduler(
            optimizer, len(train_loader) // conf["training"]["batch_size"], 64
        ),
        "interval": "step",
    }

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=schedulers,
        train_loader=train_loader,
        val_loader=val_loader,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=True)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        gradient_clip_val=conf["training"]["gradient_clipping"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #2
def main(conf):
    train_set = LibriMix(
        csv_dir=conf["data"]["train_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["data"]["n_src"],
        segment=conf["data"]["segment"],
    )

    val_set = LibriMix(
        csv_dir=conf["data"]["valid_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["data"]["n_src"],
        segment=conf["data"]["segment"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    conf["masknet"].update({"n_src": conf["data"]["n_src"]})

    model = ConvTasNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="dp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #3
def main(conf):
    train_set = WhamDataset(conf['data']['train_dir'],
                            conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'],
                          conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update number of source values (It depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer
    model = ConvTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_top_k=5,
                                 verbose=True)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=10,
                                       verbose=True)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model (next PL version will make this easier)
    best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))
Example #4
def main(conf):
    # Define dataloader using ORIGINAL mixture.
    dataset_kwargs = {
        "root_path": Path(conf["data"]["root_path"]),
        "sample_rate": conf["data"]["sample_rate"],
        "num_workers": conf["training"]["num_workers"],
        "mixture": conf["data"]["mixture"],
        "task": conf["data"]["task"],
    }

    train_set = DAMPVSEPSinglesDataset(
        split=f"train_{conf['data']['train_set']}",
        random_segments=True,
        segment=conf["data"]["segment"],
        ex_per_track=conf["data"]["ex_per_track"],
        **dataset_kwargs,
    )

    val_set = DAMPVSEPSinglesDataset(split="valid", **dataset_kwargs)

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set, shuffle=False, batch_size=1, num_workers=conf["training"]["num_workers"]
    )

    model = ConvTasNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])

    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    # Combine_Loss is not complete. Needs improvement
    # loss_func = Combine_Loss(alpha=conf['training']['loss_alpha'],
    #                          sample_rate=conf['data']['sample_rate'])
    loss_func = torch.nn.L1Loss()
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=10, verbose=True
    )

    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=20, verbose=True)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #5
def _train(args):
    train_dir = args.train
    val_dir = args.test

    with open('conf.yml') as f:
        def_conf = yaml.safe_load(f)

    pp = argparse.ArgumentParser()
    parser = prepare_parser_from_dict(def_conf, parser=pp)
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    print(arg_dic)
    conf = arg_dic

    train_set = WhamDataset_no_sf(
        train_dir,
        conf['data']['task'],
        sample_rate=conf['data']['sample_rate'],
        segment=conf['data']['segment'],
        nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset_no_sf(
        val_dir,
        conf['data']['task'],
        segment=conf['data']['segment'],
        sample_rate=conf['data']['sample_rate'],
        nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)

    # train_loader = DataLoader(train_set, shuffle=True,
    #                           batch_size=args.batch_size,
    #                           num_workers=conf['training']['num_workers'],
    #                           drop_last=True)
    # val_loader = DataLoader(val_set, shuffle=False,
    #                         batch_size=args.batch_size,
    #                         num_workers=conf['training']['num_workers'],
    #                         drop_last=True)
    # Sanity check: print one sample from each set.
    print(train_set[0])
    print(val_set[0])
    # Update number of source values (It depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    model = DPRNNTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    # exp_dir = conf['main_args']['exp_dir']
    # os.makedirs(exp_dir, exist_ok=True)
    exp_dir = args.model_dir
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)
    system.batch_size = 1

    # Define callbacks
    # checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    # checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
    #                              mode='min', save_top_k=5, verbose=1)
    # early_stopping = False
    # if conf['training']['early_stop']:
    #     early_stopping = EarlyStopping(monitor='val_loss', patience=10,
    #                                    verbose=1)

    # Don't ask GPU if they are not available.
    # print("!!!!!!!{}".format(torch.cuda.is_available()))
    # print(torch.__version__)
    gpus = -1 if torch.cuda.is_available() else None
    # trainer = pl.Trainer(max_epochs=conf['training']['epochs'],
    #                      checkpoint_callback=checkpoint,
    #                      early_stop_callback=early_stopping,
    #                      default_root_dir=exp_dir,
    #                      gpus=gpus,
    #                      distributed_backend='ddp',
    #                      gradient_clip_val=conf['training']["gradient_clipping"])
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend='ddp',
        gradient_clip_val=conf['training']["gradient_clipping"])
    trainer.fit(system)
    # print("!!!!!!!!!!!!!!")
    # print(checkpoint)
    # print(checkpoint.best_k_models)
    # print(checkpoint.best_k_models.items())
    # onlyfiles = [f for f in listdir(checkpoint_dir) if isfile(os.path.join(checkpoint_dir, f))]
    # print(onlyfiles)

    # best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    # with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
    #     json.dump(best_k, f, indent=0)

    # # Save best model (next PL version will make this easier)
    # best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]
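    # Load the temporary weights that (older) PyTorch Lightning writes at the end
    # of DDP training, since no ModelCheckpoint callback is configured here.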
    best_path = os.path.join(exp_dir, "__temp_weight_ddp_end.ckpt")
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    # to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))
Example #6
def main(conf):
    # train_set = WhamDataset(
    #     conf["data"]["train_dir"],
    #     conf["data"]["task"],
    #     sample_rate=conf["data"]["sample_rate"],
    #     segment=conf["data"]["segment"],
    #     nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    # )
    # val_set = WhamDataset(
    #     conf["data"]["valid_dir"],
    #     conf["data"]["task"],
    #     sample_rate=conf["data"]["sample_rate"],
    #     nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    # )
    train_set = LibriMix(
        csv_dir=conf["data"]["train_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["masknet"]["n_src"],
        segment=conf["data"]["segment"],
    )

    val_set = LibriMix(
        csv_dir=conf["data"]["valid_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["masknet"]["n_src"],
        segment=conf["data"]["segment"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    # TODO: redundant
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPRNNTasNet(
        **conf["filterbank"], **conf["masknet"], sample_rate=conf["data"]["sample_rate"]
    )

    # from torchsummary import summary
    # model.cuda()
    # summary(model, (24000,))
    # import pdb
    # pdb.set_trace()

    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss",
                          mode="min",
                          patience=30,
                          verbose=True))

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None

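    # Optionally resume from the most recent checkpoint in checkpoint_dir
    # (checkpoints are sorted by filename and the last one is used).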
    if conf["training"]["cont"]:
        from glob import glob
        ckpts = glob('%s/*.ckpt' % checkpoint_dir)
        ckpts.sort()
        latest_ckpt = ckpts[-1]
        trainer = pl.Trainer(
            max_epochs=conf["training"]["epochs"],
            callbacks=callbacks,
            default_root_dir=exp_dir,
            gpus=gpus,
            distributed_backend=distributed_backend,
            limit_train_batches=1.0,  # Useful for fast experiment
            gradient_clip_val=conf["training"]["gradient_clipping"],
            resume_from_checkpoint=latest_ckpt)
    else:
        trainer = pl.Trainer(
            max_epochs=conf["training"]["epochs"],
            callbacks=callbacks,
            default_root_dir=exp_dir,
            gpus=gpus,
            distributed_backend=distributed_backend,
            limit_train_batches=1.0,  # Useful for fast experiment
            gradient_clip_val=conf["training"]["gradient_clipping"],
        )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model (next PL version will make this easier)
    # best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]
    # state_dict = torch.load(best_path)
    state_dict = torch.load(checkpoint.best_model_path)
    # state_dict = torch.load('exp/train_dprnn_130d5f9a/checkpoints/epoch=154.ckpt')
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #7
def main(conf):
    exp_dir = conf["main_args"]["exp_dir"]
    # Define Dataloader
    """
    total_set = MedleydbDataset(
        conf["data"]["json_dir"],
        n_src=conf["data"]["n_inst"],
        n_poly=conf["data"]["n_poly"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        threshold=conf["data"]["threshold"],
    )

    validation_size = int(conf["data"]["validation_split"] * len(total_set))
    train_size = len(total_set) - validation_size
    torch.manual_seed(conf["training"]["random_seed"])
    train_set, val_set = data.random_split(total_set, [train_size, validation_size])
    """
    # NOTE: train_dir and val_dir are not defined in this snippet; they must be
    # provided (e.g. from the data config or the command line) before this point.
    train_set = SourceFolderDataset(
        train_dir,
        train_dir,
        conf["data"]["n_poly"],
        conf["data"]["sample_rate"],
        conf["training"]["batch_size"],
    )
    val_set = SourceFolderDataset(
        val_dir,
        val_dir,
        conf["data"]["n_poly"],
        conf["data"]["sample_rate"],
        conf["training"]["batch_size"],
    )
    
    train_loader = data.DataLoader(
        train_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = data.DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": conf["data"]["n_inst"] * conf["data"]["n_poly"]})

    model = DPRNNTasNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(EarlyStopping(monitor="val_loss", mode="min", patience=30, verbose=True))

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        gradient_clip_val=conf["training"]["gradient_clipping"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #8
def main(conf):
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPRNNTasNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 verbose=True,
                                 mode='min',
                                 save_top_k=5)

    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=True)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend='ddp',
        gradient_clip_val=conf['training']["gradient_clipping"])
    trainer.fit(system)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #9
def main(conf):

    train_set = LibriVADDataset(md_file_path=conf["data"]["train_dir"])
    val_set = LibriVADDataset(md_file_path=conf["data"]["valid_dir"])
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )

    model = VADNet(**conf["filterbank"], **conf["masknet"])

    optimizer = make_optimizer(model.parameters(), **conf["optim"])

    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = F1_loss()
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss",
                          mode="min",
                          patience=30,
                          verbose=True))

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None

    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        # limit_train_batches=0.0002,  # Useful for fast experiment
        # limit_val_batches=0.0035,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
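Example #10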
def main(conf):
    train_enh_dir = conf["main_args"].get("train_enh_dir", None)
    resume_ckpt = conf["main_args"].get("resume_ckpt", None)

    train_loader, val_loader, train_set_infos = make_dataloaders(
        corpus=conf["main_args"]["corpus"],
        train_dir=conf["data"]["train_dir"],
        val_dir=conf["data"]["valid_dir"],
        train_enh_dir=train_enh_dir,
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["data"]["n_src"],
        segment=conf["data"]["segment"],
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
    )

    if conf["main_args"]["strategy"] != "multi_task":
        conf["masknet"].update({"n_src": conf["data"]["n_src"]})
    else:
        conf["masknet"].update({"n_src": conf["data"]["n_src"] + 1})

    model = getattr(asteroid.models, conf["main_args"]["model"])(
        **conf["filterbank"], **conf["masknet"]
    )

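    # Optionally initialize the model from a pre-trained (single-source) checkpoint
    # before fine-tuning.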
    if conf["main_args"]["strategy"] == "pretrained":
        if conf["main_args"]["load_path"] is not None:
            all_states = torch.load(conf["main_args"]["load_path"],
                                    map_location="cpu")
            assert "state_dict" in all_states

            # If the checkpoint is not the serialized "best_model.pth", its keys
            # start with "model.", which must be stripped; otherwise none of the
            # parameters would be loaded.
            for key in list(all_states["state_dict"].keys()):
                if key.startswith("model"):
                    all_states["state_dict"][key.split(".", 1)[1]] = all_states["state_dict"][key]
                    del all_states["state_dict"][key]

            # For debugging, set strict=True to check whether only the following
            # parameters have different sizes (since n_src=1 for pre-training
            # and n_src=2 for fine-tuning):
            # for ConvTasNet: "masker.mask_net.1.*"
            # for DPRNNTasNet/DPTNet: "masker.first_out.1.*"
            if conf["main_args"]["model"] == "ConvTasNet":
                del all_states["state_dict"]["masker.mask_net.1.weight"]
                del all_states["state_dict"]["masker.mask_net.1.bias"]
            elif conf["main_args"]["model"] in ["DPRNNTasNet", "DPTNet"]:
                del all_states["state_dict"]["masker.first_out.1.weight"]
                del all_states["state_dict"]["masker.first_out.1.bias"]
            model.load_state_dict(all_states["state_dict"], strict=False)

    optimizer = make_optimizer(model.parameters(), **conf["optim"])

    # Define scheduler
    scheduler = None
    if conf["main_args"]["model"] in ["DPTNet", "SepFormerTasNet", "SepFormer2TasNet"]:
        steps_per_epoch = len(train_loader) // conf["main_args"]["accumulate_grad_batches"]
        conf["scheduler"]["steps_per_epoch"] = steps_per_epoch
        scheduler = {
            "scheduler": DPTNetScheduler(
                optimizer=optimizer,
                steps_per_epoch=steps_per_epoch,
                d_model=model.masker.mha_in_dim,
            ),
            "interval": "batch",
        }
    elif conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    if conf["main_args"]["strategy"] == "multi_task":
        loss_func = MultiTaskLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    else:
        loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch}-{step}',
        monitor="val_loss",
        mode="min",
        save_top_k=conf["training"]["epochs"],
        save_last=True,
        verbose=True,
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss",
                          mode="min",
                          patience=30,
                          verbose=True))

    loggers = []
    tb_logger = pl.loggers.TensorBoardLogger(os.path.join(exp_dir, "tb_logs/"))
    loggers.append(tb_logger)
    if conf["main_args"]["comet"]:
        comet_logger = pl.loggers.CometLogger(
            save_dir=os.path.join(exp_dir, "comet_logs/"),
            experiment_key=conf["main_args"].get("comet_exp_key", None),
            log_code=True,
            log_graph=True,
            parse_args=True,
            log_env_details=True,
            log_git_metadata=True,
            log_git_patch=True,
            log_env_gpu=True,
            log_env_cpu=True,
            log_env_host=True,
        )
        comet_logger.log_hyperparams(conf)
        loggers.append(comet_logger)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    # Don't use ddp for multi-task training.
    distributed_backend = "ddp" if torch.cuda.is_available() else None

    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        logger=loggers,
        callbacks=callbacks,
        # checkpoint_callback=checkpoint,
        # early_stop_callback=callbacks[1],
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        limit_train_batches=1.0,  # Useful for fast experiment
        # fast_dev_run=True, # Useful for debugging
        # overfit_batches=0.001, # Useful for debugging
        gradient_clip_val=5.0,
        accumulate_grad_batches=conf["main_args"]["accumulate_grad_batches"],
        resume_from_checkpoint=resume_ckpt,
        deterministic=True,
        replace_sampler_ddp=False if conf["main_args"]["strategy"] == "multi_task" else True,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set_infos)
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #11
def main(conf):
    # train_set = WhamDataset(
    #     conf["data"]["train_dir"],
    #     conf["data"]["task"],
    #     sample_rate=conf["data"]["sample_rate"],
    #     segment=conf["data"]["segment"],
    #     nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    # )
    # val_set = WhamDataset(
    #     conf["data"]["valid_dir"],
    #     conf["data"]["task"],
    #     sample_rate=conf["data"]["sample_rate"],
    #     nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    # )
    train_set = LibriMix(
        csv_dir=conf["data"]["train_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["masknet"]["n_src"],
        segment=conf["data"]["segment"],
    )
    print(conf["data"]["train_dir"])

    val_set = LibriMix(
        csv_dir=conf["data"]["valid_dir"],
        task=conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        n_src=conf["masknet"]["n_src"],
        segment=conf["data"]["segment"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    # TODO params
    # model = TransMask(**conf["filterbank"], **conf["masknet"])
    model = DPTrans(
        **conf["filterbank"], **conf["masknet"], sample_rate=conf["data"]["sample_rate"]
    )

    # from torchsummary import summary
    # model.cuda()
    # summary(model, (24000,))
    # import pdb
    # pdb.set_trace()

    optimizer = make_optimizer(model.parameters(), **conf["optim"])

    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # # TODO warmup for transformer
    # from asteroid.engine.schedulers import DPTNetScheduler
    # schedulers = {
    #     "scheduler": DPTNetScheduler(
    #         # optimizer, len(train_loader) // conf["training"]["batch_size"], 64
    #         # optimizer, len(train_loader), 64,
    #         optimizer, len(train_loader), 128,
    #         stride=2,
    #         # exp_max=0.0004 * 16,
    #         # warmup_steps=1000
    #     ),
    #     "interval": "batch",
    # }

    # from torch.optim.lr_scheduler import ReduceLROnPlateau
    # if conf["training"]["half_lr"]:
    #     print('Use ReduceLROnPlateau halflr...........')
    #     schedulers = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    # checkpoint_dir = os.path.join(exp_dir)
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(EarlyStopping(monitor="val_loss", mode="min", patience=30, verbose=True))

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None

    if conf["training"]["cont"]:
        from glob import glob
        ckpts = glob('%s/*.ckpt' % checkpoint_dir)
        ckpts.sort()
        latest_ckpt = ckpts[-1]
        trainer = pl.Trainer(
            max_epochs=conf["training"]["epochs"],
            callbacks=callbacks,
            default_root_dir=exp_dir,
            gpus=gpus,
            distributed_backend=distributed_backend,
            limit_train_batches=1.0,  # Useful for fast experiment
            gradient_clip_val=conf["training"]["gradient_clipping"],
            resume_from_checkpoint=latest_ckpt
        )
    else:
        trainer = pl.Trainer(
            max_epochs=conf["training"]["epochs"],
            callbacks=callbacks,
            default_root_dir=exp_dir,
            gpus=gpus,
            distributed_backend=distributed_backend,
            limit_train_batches=1.0,  # Useful for fast experiment
            gradient_clip_val=conf["training"]["gradient_clipping"],
        )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    # state_dict = torch.load('exp/train_transmask_rnn_acous_gelu_6layer_peconv_stride2_batch6/_ckpt_epoch_208.ckpt')
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #12
def main(conf):
    train_set = PodcastMixDataloader(
        csv_dir=conf["data"]["train_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"])
    val_set = PodcastMixDataloader(
        csv_dir=conf["data"]["valid_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"])
    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf["training"]["batch_size"],
                              num_workers=conf["training"]["num_workers"],
                              drop_last=True,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf["training"]["batch_size"],
                            num_workers=conf["training"]["num_workers"],
                            drop_last=True,
                            pin_memory=True)

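    # Build the model according to conf["model"]["name"]; both branches train with
    # the LogL2Time loss.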
    if conf["model"]["name"] == "ConvTasNet":
        sys.path.append('ConvTasNet_model')
        from conv_tasnet_norm import ConvTasNetNorm
        conf["masknet"].update({"n_src": conf["data"]["n_src"]})
        model = ConvTasNetNorm(**conf["filterbank"],
                               **conf["masknet"],
                               sample_rate=conf["data"]["sample_rate"])
        loss_func = LogL2Time()
        plugins = None
    elif conf["model"]["name"] == "UNet":
        # UNet with logl2 time loss and normalization inside model
        sys.path.append('UNet_model')
        from unet_model import UNet
        model = UNet(conf["data"]["sample_rate"], conf["data"]["fft_size"],
                     conf["data"]["hop_size"], conf["data"]["window_size"],
                     conf["convolution"]["kernel_size"],
                     conf["convolution"]["stride"])
        loss_func = LogL2Time()
        plugins = DDPPlugin(find_unused_parameters=False)
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Define scheduler (stays None unless half_lr is enabled).
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["model"]["name"] + "_model/" + conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss",
                          mode="min",
                          patience=100,
                          verbose=True))

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        gradient_clip_val=5.0,
        resume_from_checkpoint=conf["main_args"]["resume_from"],
        precision=32,
        plugins=plugins)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        print(best_k)
        json.dump(best_k, f, indent=0)
    print(checkpoint.best_model_path)
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))