示例#1
0
文件: train.py 项目: wyn314/asteroid
def main(conf):
    """Train a ConvTasNet with ``WhamDataset`` loaders and export the best model.

    The function builds the train/validation datasets and loaders, the model,
    the optimizer and an optional ``ReduceLROnPlateau`` scheduler, runs a
    Lightning ``Trainer`` and then writes three artefacts into the experiment
    directory: ``conf.yml`` (the configuration used), ``best_k_models.json``
    (checkpoint path -> monitored score) and ``best_model.pth`` (serialized
    model plus the dataset infos).

    Args:
        conf: Nested configuration dict. The ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args`` sections
            are read; ``conf['masknet']['n_src']`` is overwritten in place
            with the number of sources reported by the training set.
    """
    train_set = WhamDataset(conf['data']['train_dir'],
                            conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'],
                          conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update number of source values (it depends on the task). This must happen
    # before the model is built and before the config is dumped to disk.
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer
    model = ConvTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Define scheduler (only when learning-rate halving is requested)
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_top_k=5,
                                 verbose=1)
    # ``False`` is what gets passed to the trainer when early stopping is off.
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=10,
                                       verbose=1)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    # Convert the monitored scores to plain numbers so they are JSON-serializable.
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model (next PL version will make this easier).
    # The checkpoint with the lowest monitored value wins; ``min`` with a key
    # returns the first such path in a single pass.
    best_path = min(best_k, key=best_k.get)
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))
示例#2
0
def main(conf):
    """Train a ConvTasNet with ``SATBDataset`` loaders and export the best model.

    Writes ``conf.yml``, ``best_k_models.json`` and ``best_model.pth`` into the
    experiment directory given by ``conf["main_args"]["exp_dir"]``.

    Args:
        conf: Nested configuration dict. The ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args`` sections
            are read; ``conf["masknet"]["n_src"]`` is overwritten in place.
    """
    data_conf = conf["data"]
    train_conf = conf["training"]

    # Both datasets are built from the same directory; the validation set is
    # selected through the ``partition`` argument.
    shared_set_kwargs = dict(
        sample_rate=data_conf["sample_rate"],
        nondefault_nsrc=data_conf["nondefault_nsrc"],
    )
    train_set = SATBDataset(data_conf["train_dir"], **shared_set_kwargs)
    val_set = SATBDataset(data_conf["train_dir"], partition="valid", **shared_set_kwargs)

    shared_loader_kwargs = dict(
        batch_size=train_conf["batch_size"],
        num_workers=train_conf["num_workers"],
        drop_last=True,
    )
    train_loader = DataLoader(train_set, shuffle=True, **shared_loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **shared_loader_kwargs)

    # The number of sources comes from the dataset; record it before the model
    # is created and before the configuration is written to disk.
    conf["masknet"].update(n_src=train_set.n_src)

    # Model, optimizer and optional learning-rate scheduler.
    model = ConvTasNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    scheduler = (
        ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
        if train_conf["half_lr"]
        else None
    )

    # Persist the arguments right away so the run can be reloaded later.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    with open(os.path.join(exp_dir, "conf.yml"), "w") as conf_file:
        yaml.safe_dump(conf, conf_file)

    # Loss function and training system.
    system = System(
        model=model,
        loss_func=PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx"),
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Callbacks: keep the five best checkpoints, optionally stop early.
    checkpoint = ModelCheckpoint(
        os.path.join(exp_dir, "checkpoints/"),
        monitor="val_loss",
        mode="min",
        save_top_k=5,
        verbose=True,
    )
    if train_conf["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True)
    else:
        early_stopping = False

    trainer = pl.Trainer(
        max_epochs=train_conf["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        # Only request GPUs when CUDA is available.
        gpus=-1 if torch.cuda.is_available() else None,
        distributed_backend="dp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system.float())

    # Dump the monitored score of each kept checkpoint as plain numbers.
    best_k = {path: score.item() for path, score in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as json_file:
        json.dump(best_k, json_file, indent=0)

    # Reload the best checkpoint into the system and move it to the CPU.
    best_ckpt = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=best_ckpt["state_dict"])
    system.cpu()

    # Export the serialized model together with the dataset infos.
    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
示例#3
0
def main(conf):
    """Train a ConvTasNet with ``SourceFolderDataset`` loaders and export the best model.

    Writes ``conf.yml``, ``best_k_models.json`` and ``best_model.pth`` into the
    experiment directory given by ``conf["main_args"]["exp_dir"]``.

    Args:
        conf: Nested configuration dict. The ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args`` sections
            are read; ``conf["masknet"]["n_src"]`` is overwritten in place
            with ``n_inst * n_poly``.
    """
    exp_dir = conf["main_args"]["exp_dir"]

    # Define datasets.
    # NOTE(review): ``train_dir`` and ``val_dir`` are not defined inside this
    # function; presumably they are module-level names -- confirm they exist.
    train_set = SourceFolderDataset(
        train_dir,
        train_dir,
        conf["data"]["n_poly"],
        conf["data"]["sample_rate"],
        conf["training"]["batch_size"],
    )
    val_set = SourceFolderDataset(
        val_dir,
        val_dir,
        conf["data"]["n_poly"],
        conf["data"]["sample_rate"],
        conf["training"]["batch_size"],
    )

    # NOTE(review): the training loader does not shuffle; kept as in the
    # original -- confirm this is intended.
    train_loader = data.DataLoader(
        train_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = data.DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )

    # The mask network separates one source per (instrument, polyphony) pair.
    # Set this before building the model and before dumping the config.
    conf["masknet"].update(
        {"n_src": conf["data"]["n_inst"] * conf["data"]["n_poly"]})

    # Define model and optimizer
    model = ConvTasNet(**conf["filterbank"],
                       **conf["masknet"],
                       sample_rate=conf["data"]["sample_rate"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Define scheduler (only when learning-rate halving is requested)
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=3)
    # Just after instantiating, save the args. Easy loading in the future.
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks: keep the five best checkpoints, optionally stop early.
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss",
                          mode="min",
                          patience=30,
                          verbose=True))

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        limit_train_batches=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )

    trainer.fit(system)

    # Convert the monitored scores to plain numbers so they are JSON-serializable.
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model: the checkpoint with the lowest monitored value wins.
    # ``min`` with a key returns the first such path in a single pass.
    best_path = min(best_k, key=best_k.get)
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
示例#4
0
文件: train.py 项目: xmpx/asteroid
def main(conf):
    """Train a ConvTasNet with ``DAMPVSEPSinglesDataset`` loaders and export the best model.

    Writes ``conf.yml``, ``best_k_models.json`` and ``best_model.pth`` into the
    experiment directory given by ``conf["main_args"]["exp_dir"]``.

    Args:
        conf: Nested configuration dict. The ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args`` sections
            are read.
    """
    data_conf = conf["data"]
    train_conf = conf["training"]
    n_workers = train_conf["num_workers"]

    # Keyword arguments shared by the training and validation datasets
    # (datasets are defined using the ORIGINAL mixture).
    dataset_kwargs = dict(
        root_path=Path(data_conf["root_path"]),
        sample_rate=data_conf["sample_rate"],
        num_workers=n_workers,
        mixture=data_conf["mixture"],
        task=data_conf["task"],
    )

    # Training examples are random segments; validation uses the dataset's
    # defaults for the "valid" split.
    train_set = DAMPVSEPSinglesDataset(
        split=f"train_{conf['data']['train_set']}",
        random_segments=True,
        segment=data_conf["segment"],
        ex_per_track=data_conf["ex_per_track"],
        **dataset_kwargs,
    )
    val_set = DAMPVSEPSinglesDataset(split="valid", **dataset_kwargs)

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=train_conf["batch_size"],
        num_workers=n_workers,
        drop_last=True,
    )
    # Validation is run one example per batch.
    val_loader = DataLoader(val_set, shuffle=False, batch_size=1, num_workers=n_workers)

    # Model, optimizer and optional learning-rate scheduler.
    model = ConvTasNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    scheduler = (
        ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
        if train_conf["half_lr"]
        else None
    )

    # Persist the arguments right away so the run can be reloaded later.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    with open(os.path.join(exp_dir, "conf.yml"), "w") as conf_file:
        yaml.safe_dump(conf, conf_file)

    # Loss function: plain L1 between estimate and target.
    system = System(
        model=model,
        loss_func=torch.nn.L1Loss(),
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Callbacks: keep the ten best checkpoints, optionally stop early.
    checkpoint = ModelCheckpoint(
        os.path.join(exp_dir, "checkpoints/"),
        monitor="val_loss",
        mode="min",
        save_top_k=10,
        verbose=True,
    )
    if train_conf["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=20, verbose=True)
    else:
        early_stopping = False

    trainer = pl.Trainer(
        max_epochs=train_conf["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        # Only request GPUs when CUDA is available.
        gpus=-1 if torch.cuda.is_available() else None,
        distributed_backend="ddp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
    )
    trainer.fit(system)

    # Dump the monitored score of each kept checkpoint as plain numbers.
    best_k = {path: score.item() for path, score in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as json_file:
        json.dump(best_k, json_file, indent=0)

    # Reload the best checkpoint into the system and move it to the CPU.
    best_ckpt = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=best_ckpt["state_dict"])
    system.cpu()

    # Export the serialized model together with the dataset infos.
    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
示例#5
0
        te_x = mix_signals_batch(te_s, te_n, args.snr_ranges)

        # Load student model for personalization
        if args.is_g_ctn:
            from asteroid.models import ConvTasNet
            G_model = ConvTasNet(n_src=1)
            ctn_dir = "new_models_results/Gs/expr04221819_SE_G3x1024_lr5e-04_bs20_ctnTruesm-1_nfrms16000_GPU1/"
            G_model.load_state_dict(
                torch.load("{}/Dmodel_best.pt".format(ctn_dir)))
        else:
            G_model = SpeechEnhancementModel(args.G_hidden_size,
                                             args.G_num_layers,
                                             args.stft_features)
            load_model(G_model, args.load_SEmodel)
        G_model = G_model.to(args.device)
        G_optimizer = torch.optim.Adam(params=G_model.parameters(),
                                       lr=args.learning_rate)

        if args.is_wgan:
            D_model = Discriminator_RNN_Utt_Jit(args.D_hidden_size,
                                                args.D_num_layers,
                                                args.stft_features,
                                                args.stft_frames)
        else:
            D_model = Discriminator_RNN_BC(args.D_hidden_size,
                                           args.D_num_layers,
                                           args.stft_features,
                                           args.stft_frames)
        D_optimizer = torch.optim.Adam(
            params=D_model.parameters(),
            lr=args.learning_rate * 4 if args.is_ttur else args.learning_rate)