# Example #1
# 0
# MiniLibriMix is a tiny version of LibriMix (https://github.com/JorisCos/LibriMix),
# which is a free speech separation dataset.
from asteroid.data import LibriMix
# Asteroid's System is a convenience wrapper for PyTorch-Lightning.
from asteroid.engine import System

def str_to_bool(value):
    """Convert a command-line string such as ``true``/``false``/``1``/``0`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, which is
    ``True`` for every non-empty string (including ``"False"``), so an explicit
    converter is needed for boolean options that take a value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    from argparse import ArgumentTypeError

    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def main():
    """Train DPRNNTasNet on MiniLibriMix with a permutation-invariant SI-SDR loss."""
    from argparse import ArgumentParser

    import torch

    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234)
    # Not consulted below: loaders_from_mini downloads MiniLibriMix itself on
    # the first run. Parsed with str_to_bool so "--download False" is honoured.
    parser.add_argument('--download', type=str_to_bool, default=True)
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--gpus', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=16)
    args = parser.parse_args()

    # Seed torch so torch-based randomness (e.g. weight initialisation) is
    # reproducible across runs with the same --seed.
    torch.manual_seed(args.seed)

    # This will automatically download MiniLibriMix from Zenodo on the first run.
    train_loader, val_loader = LibriMix.loaders_from_mini(
        task="sep_clean", batch_size=args.batch_size)

    # Tell DPRNN that we want to separate to 2 sources.
    model = DPRNNTasNet(n_src=2)

    # PITLossWrapper works with any loss function.
    loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    system = System(model, optimizer, loss, train_loader, val_loader)

    # Train for --max_epochs epochs (default 1); --gpus is passed straight to the
    # Trainer (default None). If you're running this on Google Colab, be sure to
    # select a GPU runtime (Runtime → Change runtime type → Hardware accelerator).
    trainer = Trainer(max_epochs=args.max_epochs, gpus=args.gpus)
    trainer.fit(system)


if __name__ == '__main__':
    main()
# Example #2
# 0
def _train(args):
    """Train DPRNNTasNet on WHAM-format data and export the final model.

    Args:
        args: parsed job arguments. This function reads ``args.train``
            (training data directory), ``args.test`` (validation data
            directory), ``args.model_dir`` (output directory) and
            ``args.epochs`` (maximum number of training epochs).

    Side effects:
        * reads ``conf.yml`` from the current working directory and passes it
          through ``prepare_parser_from_dict`` / ``parse_args_as_dict``
          (presumably so command-line flags can override config values);
        * creates ``args.model_dir`` if needed and writes the resolved config
          to ``<model_dir>/conf.yml``;
        * trains with PyTorch Lightning and saves ``system.model.serialize()``
          to ``<model_dir>/best_model.pth``.
    """
    import warnings

    with open('conf.yml') as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=argparse.ArgumentParser())
    # The plain argparse namespace (second return value) is not needed here.
    conf, _ = parse_args_as_dict(parser, return_plain_args=True)
    print(conf)

    # Train and validation sets share every option except their directory.
    data_conf = conf['data']
    dataset_kwargs = dict(
        sample_rate=data_conf['sample_rate'],
        segment=data_conf['segment'],
        nondefault_nsrc=data_conf['nondefault_nsrc'],
    )
    train_set = WhamDataset_no_sf(args.train, data_conf['task'], **dataset_kwargs)
    val_set = WhamDataset_no_sf(args.test, data_conf['task'], **dataset_kwargs)

    loader_kwargs = dict(
        batch_size=conf['training']['batch_size'],
        num_workers=conf['training']['num_workers'],
        # Also applied to validation, so a trailing partial batch is skipped.
        drop_last=True,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

    # The number of sources depends on the task, so take it from the dataset.
    conf['masknet'].update({'n_src': train_set.n_src})

    model = DPRNNTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    scheduler = None
    if conf['training']['half_lr']:
        # factor=0.5 halves the learning rate each time a plateau is detected.
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)

    exp_dir = args.model_dir
    # Writing conf.yml below fails if the output directory does not exist yet.
    os.makedirs(exp_dir, exist_ok=True)
    # Save the resolved config right after instantiation for easy reloading.
    with open(os.path.join(exp_dir, 'conf.yml'), 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)
    # NOTE(review): how System uses this attribute is not visible here; kept as-is.
    system.batch_size = 1

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend='ddp',
        gradient_clip_val=conf['training']["gradient_clipping"])
    trainer.fit(system)

    # NOTE(review): this looks like the temporary end-of-training checkpoint that
    # (some versions of) Lightning's 'ddp' backend writes into default_root_dir;
    # confirm for the installed version. It holds the final weights. No
    # validation-based "best" selection happens here, despite the output name.
    final_ckpt = os.path.join(exp_dir, "__temp_weight_ddp_end.ckpt")
    if os.path.isfile(final_ckpt):
        # The model is exported from CPU anyway, so load straight onto CPU and
        # don't depend on the device the checkpoint was written from.
        checkpoint = torch.load(final_ckpt, map_location='cpu')
        system.load_state_dict(state_dict=checkpoint['state_dict'])
    else:
        # E.g. when training did not go through the DDP path (no GPU); the
        # in-process `system` then holds whatever trainer.fit left in it.
        warnings.warn(
            "{} not found; exporting the weights currently held by the "
            "in-process model.".format(final_ckpt))
    system.cpu()

    to_save = system.model.serialize()
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))