예제 #1
0
파일: run.py 프로젝트: yongjiankuang/onssen
def main():
    """Train and then evaluate a chimera model described by a JSON config.

    The config path comes from the ``-c/--config`` command-line option.
    The loaded config is wrapped in ``AttrDict`` and used as the run
    namespace: the model, the 'train'/'validation' data loaders, the
    optimizer and the loss function are attached to it before it is passed
    to ``onssen.utils.trainer`` and then ``onssen.utils.tester``.
    """
    parser = argparse.ArgumentParser(description='Parse the config path')
    parser.add_argument(
        "-c",
        "--config",
        dest="path",
        # Nothing can run without a config; fail with a usage message
        # rather than a TypeError from open(None).
        required=True,
        help=
        'The path to the config file. e.g. python run.py --config dc_config.json'
    )

    config = parser.parse_args()
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale.
    with open(config.path, encoding="utf-8") as f:
        args = AttrDict(json.load(f))
    device = torch.device(args.device)
    # NOTE(review): the other run scripts call chimera(**model_options);
    # confirm onssen.nn.chimera here really takes the options positionally.
    args.model = onssen.nn.chimera(args.model_options)
    args.model.to(device)
    # main() is a plain function (there is no `self`), so the loaders get
    # the local torch.device built above.
    args.train_loader = data.edinburgh_tts_dataloader(args.model_name,
                                                      args.feature_options,
                                                      'train',
                                                      args.cuda_option,
                                                      device)
    args.valid_loader = data.edinburgh_tts_dataloader(args.model_name,
                                                      args.feature_options,
                                                      'validation',
                                                      args.cuda_option,
                                                      device)
    args.optimizer = utils.build_optimizer(args.model.parameters(),
                                           args.optimizer_options)
    args.loss_fn = loss.loss_chimera_psa
    trainer = onssen.utils.trainer(args)
    trainer.run()

    tester = onssen.utils.tester(args)
    tester.eval()
예제 #2
0
파일: run.py 프로젝트: yongjiankuang/onssen
def main():
    """Train and then evaluate a deep-clustering model from a JSON config.

    The config path comes from the ``-c/--config`` command-line option.
    The loaded config is wrapped in ``AttrDict`` and used as the run
    namespace: the model, wsj0-2mix loaders for the 'tr', 'cv' and 'tt'
    splits, the optimizer and ``loss.loss_dc`` are attached to it before it
    is passed to ``utils.trainer`` and then ``tester_dc``.
    """
    parser = argparse.ArgumentParser(description='Parse the config path')
    parser.add_argument(
        "-c",
        "--config",
        dest="path",
        # Nothing can run without a config; fail with a usage message
        # rather than a TypeError from open(None).
        required=True,
        help=
        'The path to the config file. e.g. python run.py --config config.json')

    config = parser.parse_args()
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale.
    with open(config.path, encoding="utf-8") as f:
        args = AttrDict(json.load(f))
    device = torch.device(args.device)
    args.model = nn.deep_clustering(**(args['model_options']))
    args.model.to(device)
    args.train_loader = data.wsj0_2mix_dataloader(args.model_name,
                                                  args.feature_options, 'tr',
                                                  device)
    args.valid_loader = data.wsj0_2mix_dataloader(args.model_name,
                                                  args.feature_options, 'cv',
                                                  device)
    args.test_loader = data.wsj0_2mix_dataloader(args.model_name,
                                                 args.feature_options, 'tt',
                                                 device)
    args.optimizer = utils.build_optimizer(args.model.parameters(),
                                           args.optimizer_options)
    args.loss_fn = loss.loss_dc
    trainer = utils.trainer(args)
    trainer.run()

    tester = tester_dc(args)
    tester.eval()
예제 #3
0
파일: run.py 프로젝트: yongjiankuang/onssen
def main():
    """Train and evaluate a ConvTasNet model configured by ./config.json.

    The JSON config is wrapped in ``AttrDict`` and serves as the run
    namespace. Its ``device`` entry is replaced by the resolved
    ``torch.device``. The model, wsj0-2mix loaders for the 'tr', 'cv' and
    'tt' splits, the optimizer and ``loss.si_snr_loss`` are stored on it.
    It is then handed to ``utils.trainer`` and finally to ``tester_tasnet``.
    """
    with open('./config.json') as cfg_file:
        cfg = AttrDict(json.load(cfg_file))

    target = torch.device(cfg.device)
    # Downstream consumers see the torch.device object, not the raw string.
    cfg.device = target

    cfg.model = nn.ConvTasNet(**cfg["model_options"])
    cfg.model.to(target)

    # One loader per split, created in train -> validation -> test order.
    split_plan = (
        ('train_loader', 'tr'),
        ('valid_loader', 'cv'),
        ('test_loader', 'tt'),
    )
    for attr_name, split in split_plan:
        setattr(cfg, attr_name,
                data.wsj0_2mix_dataloader(cfg.model_name,
                                          cfg.feature_options,
                                          split,
                                          target))

    cfg.optimizer = utils.build_optimizer(cfg.model.parameters(),
                                          cfg.optimizer_options)
    cfg.loss_fn = loss.si_snr_loss

    runner = utils.trainer(cfg)
    runner.run()
    evaluator = tester_tasnet(cfg)
    evaluator.eval()
예제 #4
0
def main():
    """Train and evaluate a chimera model configured by ./config.json.

    The JSON config is wrapped in ``AttrDict`` and used as the run
    namespace. The model, wsj0-2mix loaders for the 'tr', 'cv' and 'tt'
    splits, the optimizer and ``loss.loss_chimera_msa`` are attached to it.
    It is then passed to ``utils.trainer`` and afterwards to
    ``tester_chimera``.
    """
    with open('./config.json') as handle:
        settings = AttrDict(json.load(handle))
    dev = torch.device(settings.device)

    settings.model = nn.chimera(**(settings['model_options']))
    settings.model.to(dev)

    def _loader(split):
        # All splits share the model name, feature options and device.
        return data.wsj0_2mix_dataloader(settings.model_name,
                                         settings.feature_options,
                                         split,
                                         dev)

    settings.train_loader = _loader('tr')
    settings.valid_loader = _loader('cv')
    settings.test_loader = _loader('tt')

    settings.optimizer = utils.build_optimizer(settings.model.parameters(),
                                               settings.optimizer_options)
    settings.loss_fn = loss.loss_chimera_msa

    fit = utils.trainer(settings)
    fit.run()
    check = tester_chimera(settings)
    check.eval()