Ejemplo n.º 1
0
def main():
    """Train and then evaluate an ``nn.deep_clustering`` model from a JSON config.

    The config file path comes from the ``-c``/``--config`` command-line
    option and defaults to ``./config.json``. The loaded JSON object is
    wrapped in ``AttrDict`` and used as the shared ``args`` namespace: this
    function attaches the model, the ``'tr'``/``'cv'``/``'tt'`` data loaders,
    the optimizer and the loss function (``loss.loss_dc``) to it, then runs
    ``utils.trainer(args)`` followed by ``tester_dc(args).eval()``.

    The config must provide at least ``device``, ``model_options``,
    ``model_name``, ``feature_options`` and ``optimizer_options``.
    """
    parser = argparse.ArgumentParser(description='Parse the config path')
    parser.add_argument(
        "-c",
        "--config",
        dest="path",
        # Without a default, omitting -c left ``path`` as None and the
        # ``open`` below failed with an unhelpful TypeError.
        default='./config.json',
        help='The path to the config file (default: %(default)s). '
        'e.g. python run.py --config config.json')

    config = parser.parse_args()
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(config.path, encoding='utf-8') as f:
        args = AttrDict(json.load(f))

    device = torch.device(args.device)
    args.model = nn.deep_clustering(**(args['model_options']))
    args.model.to(device)

    # One loader per split: training, cross-validation and test.
    args.train_loader = data.wsj0_2mix_dataloader(args.model_name,
                                                  args.feature_options, 'tr',
                                                  device)
    args.valid_loader = data.wsj0_2mix_dataloader(args.model_name,
                                                  args.feature_options, 'cv',
                                                  device)
    args.test_loader = data.wsj0_2mix_dataloader(args.model_name,
                                                 args.feature_options, 'tt',
                                                 device)

    args.optimizer = utils.build_optimizer(args.model.parameters(),
                                           args.optimizer_options)
    args.loss_fn = loss.loss_dc

    trainer = utils.trainer(args)
    trainer.run()

    tester = tester_dc(args)
    tester.eval()
Ejemplo n.º 2
0
def main():
    """Train and then evaluate an ``nn.ConvTasNet`` model.

    Settings are read from ``./config.json`` and wrapped in ``AttrDict``.
    That object doubles as the ``args`` namespace handed to ``utils.trainer``
    and ``tester_tasnet``. Before training, the configured ``device`` value
    is replaced in ``args`` by the matching ``torch.device``. The model, the
    ``'tr'``/``'cv'``/``'tt'`` data loaders, the optimizer and the loss
    (``loss.si_snr_loss``) are also attached to ``args``.
    """
    cfg_file = './config.json'
    with open(cfg_file) as handle:
        args = AttrDict(json.load(handle))

    # Store the torch.device object itself so everything that receives
    # ``args`` sees it instead of the raw config value.
    target = torch.device(args.device)
    args.device = target

    args.model = nn.ConvTasNet(**args["model_options"])
    args.model.to(target)

    # Attribute name on ``args`` paired with the dataset split it loads.
    split_table = (
        ('train_loader', 'tr'),
        ('valid_loader', 'cv'),
        ('test_loader', 'tt'),
    )
    for attr_name, split in split_table:
        loader = data.wsj0_2mix_dataloader(args.model_name,
                                           args.feature_options, split,
                                           target)
        setattr(args, attr_name, loader)

    params = args.model.parameters()
    args.optimizer = utils.build_optimizer(params, args.optimizer_options)
    args.loss_fn = loss.si_snr_loss

    fitter = utils.trainer(args)
    fitter.run()
    evaluator = tester_tasnet(args)
    evaluator.eval()
Ejemplo n.º 3
0
def main():
    """Train and then evaluate an ``nn.chimera`` model.

    Settings are read from ``./config.json`` and wrapped in ``AttrDict``.
    That object doubles as the ``args`` namespace handed to ``utils.trainer``
    and ``tester_chimera``. The model, the ``'tr'``/``'cv'``/``'tt'`` data
    loaders, the optimizer and the loss (``loss.loss_chimera_msa``) are
    attached to it first. Unlike the ConvTasNet entry point,
    ``args.device`` is left as the raw config value.
    """
    cfg_file = './config.json'
    with open(cfg_file) as handle:
        args = AttrDict(json.load(handle))

    target = torch.device(args.device)
    args.model = nn.chimera(**(args['model_options']))
    args.model.to(target)

    # Attribute name on ``args`` paired with the dataset split it loads.
    split_table = (
        ('train_loader', 'tr'),
        ('valid_loader', 'cv'),
        ('test_loader', 'tt'),
    )
    for attr_name, split in split_table:
        loader = data.wsj0_2mix_dataloader(args.model_name,
                                           args.feature_options, split,
                                           target)
        setattr(args, attr_name, loader)

    params = args.model.parameters()
    args.optimizer = utils.build_optimizer(params, args.optimizer_options)
    args.loss_fn = loss.loss_chimera_msa

    fitter = utils.trainer(args)
    fitter.run()
    evaluator = tester_chimera(args)
    evaluator.eval()