def test_boolean(parser): fake_args = ["--key3", "y"] arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True, args=fake_args) assert plain_args.key3 is True fake_args = ["--key3", "n"] arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True, args=fake_args) assert plain_args.key3 is False
def test_namespace_dic(parser): fake_args = ["--key2", "hey", "--key3", "0"] arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True, args=fake_args) assert arg_dic["main_args"]["main_key"] == plain_args.main_key assert arg_dic["top2"]["key3"] == plain_args.key3
def test_none_default(parser, inp):
    # If the default is None, convert the input string into an int, a float
    # or a string.
    fake_args = ["--key2", str(inp)]  # Note: inp is converted to a string
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True, args=fake_args)
    assert type(plain_args.key2) == type(inp)
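

# The tests above rely on a `parser` fixture and an `inp` parametrization that
# are not shown here. Below is a minimal sketch of how they could be defined,
# assuming a nested default config where key2 defaults to None (its group name,
# "top1", is an assumption), key3 is a boolean under "top2", and --main_key is
# added directly to the parser. Only the names appearing in the assertions
# above are taken from the tests; the rest is illustrative.
import argparse

import pytest

from asteroid.utils import prepare_parser_from_dict


@pytest.fixture
def parser():
    # Nested dict standing in for a conf.yml: each top-level key becomes an
    # argument group, each leaf a --flag whose default is the dict value.
    def_conf = dict(
        top1=dict(key1=2, key2=None),  # key2: None default, type inferred at parse time
        top2=dict(key3=True),  # key3: boolean, accepts strings like "y"/"n"/"0"
    )
    pars = argparse.ArgumentParser()
    pars.add_argument("--main_key", default=None)  # ends up under arg_dic["main_args"]
    return prepare_parser_from_dict(def_conf, parser=pars)


# test_none_default would then be parametrized over inputs of several types, e.g.:
# @pytest.mark.parametrize("inp", [1, 1.5, "a_string"])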

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))


if __name__ == "__main__":
    import yaml
    from pprint import pprint
    from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict

    # We start by opening the config file conf.yml as a dictionary from which
    # we can create parsers. Each top-level key in the dictionary defined by
    # the YAML file creates a group in the parser.
    with open("local/conf.yml") as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    # Arguments are then parsed into a hierarchical dictionary (instead of a
    # flat one, as returned by argparse) to facilitate calls to the different
    # asteroid methods (see main).
    # plain_args is the direct output of parser.parse_args() and contains all
    # the attributes in a non-hierarchical structure. It can be useful to have
    # it as well, so we include it here, but it is not used.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    pprint(arg_dic)
    main(arg_dic)
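

# To make the hierarchical output concrete, here is a small, self-contained
# sketch of prepare_parser_from_dict / parse_args_as_dict. The group and key
# names below (filterbank, training, exp_dir, ...) are illustrative assumptions
# mirroring a typical conf.yml, not this recipe's actual configuration; the
# grouping behaviour is the one described in the comments above.
import argparse
from pprint import pprint

from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict

def_conf = {
    "filterbank": {"n_filters": 64, "kernel_size": 16},
    "training": {"epochs": 200, "half_lr": True},
}
parser = argparse.ArgumentParser()
parser.add_argument("--exp_dir", default="exp/tmp")  # not in def_conf -> "main_args"
parser = prepare_parser_from_dict(def_conf, parser=parser)

arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True, args=[])
# plain_args is flat (plain_args.n_filters, plain_args.epochs, plain_args.exp_dir),
# while arg_dic groups the same values by the top-level keys of def_conf, plus a
# "main_args" group for arguments added directly to the parser, roughly:
# {"main_args": {"exp_dir": "exp/tmp"},
#  "filterbank": {"n_filters": 64, "kernel_size": 16},
#  "training": {"epochs": 200, "half_lr": True}}
pprint(arg_dic)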


def _train(args):
    train_dir = args.train
    val_dir = args.test

    with open('conf.yml') as f:
        def_conf = yaml.safe_load(f)
    pp = argparse.ArgumentParser()
    parser = prepare_parser_from_dict(def_conf, parser=pp)
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    print(arg_dic)
    conf = arg_dic

    train_set = WhamDataset_no_sf(
        train_dir,
        conf['data']['task'],
        sample_rate=conf['data']['sample_rate'],
        segment=conf['data']['segment'],
        nondefault_nsrc=conf['data']['nondefault_nsrc'],
    )
    val_set = WhamDataset_no_sf(
        val_dir,
        conf['data']['task'],
        segment=conf['data']['segment'],
        sample_rate=conf['data']['sample_rate'],
        nondefault_nsrc=conf['data']['nondefault_nsrc'],
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf['training']['batch_size'],
        num_workers=conf['training']['num_workers'],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf['training']['batch_size'],
        num_workers=conf['training']['num_workers'],
        drop_last=True,
    )
    # Alternative: take the batch size from the CLI instead of conf.yml.
    # train_loader = DataLoader(train_set, shuffle=True,
    #                           batch_size=args.batch_size,
    #                           num_workers=conf['training']['num_workers'],
    #                           drop_last=True)
    # val_loader = DataLoader(val_set, shuffle=False,
    #                         batch_size=args.batch_size,
    #                         num_workers=conf['training']['num_workers'],
    #                         drop_last=True)

    # Update the number of sources (it depends on the task).
    conf['masknet'].update({'n_src': train_set.n_src})

    model = DPRNNTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])

    # Define scheduler.
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    # exp_dir = conf['main_args']['exp_dir']
    # os.makedirs(exp_dir, exist_ok=True)
    exp_dir = args.model_dir
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )
    system.batch_size = 1

    # Callbacks are disabled here; the commented block shows how they would be defined.
    # checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    # checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
    #                              mode='min', save_top_k=5, verbose=1)
    # early_stopping = False
    # if conf['training']['early_stop']:
    #     early_stopping = EarlyStopping(monitor='val_loss', patience=10,
    #                                    verbose=1)

    # Don't ask for GPUs if they are not available.
# print("!!!!!!!{}".format(torch.cuda.is_available())) # print(torch.__version__) gpus = -1 if torch.cuda.is_available() else None # trainer = pl.Trainer(max_epochs=conf['training']['epochs'], # checkpoint_callback=checkpoint, # early_stop_callback=early_stopping, # default_root_dir=exp_dir, # gpus=gpus, # distributed_backend='ddp', # gradient_clip_val=conf['training']["gradient_clipping"]) trainer = pl.Trainer( max_epochs=args.epochs, default_root_dir=exp_dir, gpus=gpus, distributed_backend='ddp', gradient_clip_val=conf['training']["gradient_clipping"]) trainer.fit(system) # print("!!!!!!!!!!!!!!") # print(checkpoint) # print(checkpoint.best_k_models) # print(checkpoint.best_k_models.items()) # onlyfiles = [f for f in listdir(checkpoint_dir) if isfile(os.path.join(checkpoint_dir, f))] # print(onlyfiles) # best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} # with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f: # json.dump(best_k, f, indent=0) # # Save best model (next PL version will make this easier) # best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0] best_path = os.path.join(exp_dir, "__temp_weight_ddp_end.ckpt") state_dict = torch.load(best_path) system.load_state_dict(state_dict=state_dict['state_dict']) system.cpu() to_save = system.model.serialize() # to_save.update(train_set.get_infos()) torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))