Пример #1
0
def test_boolean(parser):
    """--key3 must turn the strings 'y' and 'n' into the booleans True and False."""
    for raw_value, expected in (("y", True), ("n", False)):
        # Only the flat namespace is inspected; the nested dict is discarded.
        _, plain_args = parse_args_as_dict(
            parser, return_plain_args=True, args=["--key3", raw_value]
        )
        assert plain_args.key3 is expected
Пример #2
0
def test_boolean(parser):
    """Check that --key3 accepts 'y'/'n' and yields real booleans (not truthy strings)."""
    def parse_key3(value):
        # Parse a single --key3 option and return the value from the flat namespace.
        _, namespace = parse_args_as_dict(parser,
                                          return_plain_args=True,
                                          args=['--key3', value])
        return namespace.key3

    assert parse_key3('y') is True
    assert parse_key3('n') is False
Пример #3
0
def test_namespace_dic(parser):
    """Values in the group-nested dict must match the flat argparse namespace."""
    cli_args = ['--key2', 'hey', '--key3', '0']
    nested, flat = parse_args_as_dict(parser, return_plain_args=True,
                                      args=cli_args)
    main_group = nested['main_args']
    assert main_group['main_key'] == flat.main_key
    top2_group = nested['top2']
    assert top2_group['key3'] == flat.key3
Пример #4
0
def test_namespace_dic(parser):
    """Each value reachable via arg_dic[group][key] equals the namespace attribute `key`."""
    arg_dic, plain_args = parse_args_as_dict(
        parser,
        return_plain_args=True,
        args=["--key2", "hey", "--key3", "0"],
    )
    # Checked in order: main_args/main_key first, then top2/key3.
    for group, key in (("main_args", "main_key"), ("top2", "key3")):
        assert arg_dic[group][key] == getattr(plain_args, key)
Пример #5
0
def test_none_default(parser, inp):
    """An argument whose default is None must come back with the same type as ``inp``.

    ``inp`` is passed on the command line as ``str(inp)``; the parser is
    expected to convert that string back into an int, a float or a string.
    """
    fake_args = ['--key2', str(inp)]  # inp is stringified, as argparse would see it
    _, plain_args = parse_args_as_dict(parser,
                                       return_plain_args=True,
                                       args=fake_args)
    # Exact type match is intended (isinstance would accept e.g. bool for int),
    # so compare the type objects by identity.
    assert type(plain_args.key2) is type(inp), (
        'expected {}, got {}'.format(type(inp).__name__,
                                     type(plain_args.key2).__name__))
Пример #6
0
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))


if __name__ == "__main__":
    from pprint import pprint

    import yaml
    from asteroid.utils import parse_args_as_dict, prepare_parser_from_dict

    # Read the default configuration from local/conf.yml. Each top-level key of
    # the resulting mapping becomes an argument group of the parser.
    with open("local/conf.yml") as conf_file:
        def_conf = yaml.safe_load(conf_file)
    parser = prepare_parser_from_dict(def_conf, parser=parser)

    # parse_args_as_dict returns the parsed options nested by group, which is
    # convenient for handing whole sections to asteroid APIs inside main().
    # With return_plain_args=True it also returns the flat argparse namespace.
    # That namespace is kept in plain_args for reference only and is not used below.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    pprint(arg_dic)
    main(arg_dic)
Пример #7
0
def _make_wham_set(root_dir, data_conf):
    """Build a ``WhamDataset_no_sf`` rooted at ``root_dir``.

    ``data_conf`` is the ``data`` section of the parsed config. Its ``task``,
    ``sample_rate``, ``segment`` and ``nondefault_nsrc`` entries are used.
    """
    return WhamDataset_no_sf(
        root_dir,
        data_conf['task'],
        sample_rate=data_conf['sample_rate'],
        segment=data_conf['segment'],
        nondefault_nsrc=data_conf['nondefault_nsrc'])


def _make_loader(dataset, training_conf, shuffle):
    """Wrap ``dataset`` in a DataLoader configured from the ``training`` config section."""
    return DataLoader(dataset,
                      shuffle=shuffle,
                      batch_size=training_conf['batch_size'],
                      num_workers=training_conf['num_workers'],
                      drop_last=True)


def _train(args):
    """Train a DPRNNTasNet on WHAM data and save the resulting model weights.

    Args:
        args: namespace providing ``train`` and ``test`` (dataset roots),
            ``model_dir`` (output directory) and ``epochs`` (max epochs).

    Side effects: prints the parsed configuration, creates ``args.model_dir``
    if needed, and writes ``conf.yml`` and ``best_model.pth`` into it.
    """
    train_dir = args.train
    val_dir = args.test

    with open('conf.yml', encoding='utf-8') as f:
        def_conf = yaml.safe_load(f)

    # Every top-level key of conf.yml becomes an argument group; the parsed
    # result is a dict nested by group (conf['data'], conf['training'], ...).
    # NOTE(review): presumably this parses sys.argv again — confirm it accepts
    # the same command line that produced ``args``.
    parser = prepare_parser_from_dict(def_conf, parser=argparse.ArgumentParser())
    conf, _ = parse_args_as_dict(parser, return_plain_args=True)
    print(conf)

    train_set = _make_wham_set(train_dir, conf['data'])
    val_set = _make_wham_set(val_dir, conf['data'])
    train_loader = _make_loader(train_set, conf['training'], shuffle=True)
    val_loader = _make_loader(val_set, conf['training'], shuffle=False)

    # The number of sources depends on the task, so take it from the dataset.
    conf['masknet'].update({'n_src': train_set.n_src})

    model = DPRNNTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Optionally halve the LR (factor=0.5) when the monitored metric stops
    # improving (patience=5).
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)

    # Save the resolved config next to the model so it can be reloaded later.
    # The directory is created first; writing into a missing dir would fail.
    exp_dir = args.model_dir
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w', encoding='utf-8') as outfile:
        yaml.safe_dump(conf, outfile)

    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)
    # NOTE(review): overrides System.batch_size after construction; the reason
    # is not visible here — confirm it is still needed.
    system.batch_size = 1

    # -1 asks Lightning for all available GPUs; None trains on CPU.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend='ddp',
        gradient_clip_val=conf['training']['gradient_clipping'])
    trainer.fit(system)

    # NOTE(review): relies on '__temp_weight_ddp_end.ckpt' existing in exp_dir
    # after fit(). It is presumably written by Lightning's ddp backend in the
    # pinned version; no ModelCheckpoint callback is configured — confirm.
    best_path = os.path.join(exp_dir, '__temp_weight_ddp_end.ckpt')
    # Load onto CPU directly, since the model is moved to CPU right after.
    state_dict = torch.load(best_path, map_location='cpu')
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))