示例#1
0
 def test_with_file_cache(self):
     """Parse a ``with_file_cache`` config twice against the same cache path.

     The first parse selects option 0; the second parse selects option 1,
     yet ``/main`` is still expected to be 0, i.e. the result stored at
     ``path`` by the first parse must be returned again.
     """
     def build_config(cache_path, selected):
         # A "select" node choosing options["key"], wrapped in a file cache.
         return {
             "main": {
                 "type": "with_file_cache",
                 "path": cache_path,
                 "config": {
                     "type": "select",
                     "key": "key",
                     "options": {"key": selected},
                 },
             },
         }

     with tempfile.TemporaryDirectory() as workdir:
         cache_path = os.path.join(workdir, "cache")

         first = parse_config(build_config(cache_path, 0))
         assert first["/main"] == 0

         # Same cache path, different option: the first result must persist.
         second = parse_config(build_config(cache_path, 1))
         assert second["/main"] == 0
示例#2
0
def launch(config_file: str, option: Optional[str], tmpdir: str,
           rank: Optional[int], n_process: Optional[int]):
    """Load a config, initialize distributed state, seed RNGs and run ``/main``.

    Args:
        config_file: Path handed to ``load_config``.
        option: ``"test"`` rewrites the configs with ``modify_config_for_test``;
            ``"profile"`` rewrites them with ``modify_config_for_profile`` and
            runs ``/main`` under both cProfile and ``profile()``, saving the
            results into the config's ``output_dir``. Any other value
            (including ``None``) runs ``/main`` on the unmodified configs.
        tmpdir: Forwarded to the config modifiers and to
            ``distributed.initialize``.
        rank: Forwarded to ``distributed.initialize``; the effective rank used
            afterwards is read back from ``distributed.rank()``.
        n_process: Forwarded to ``distributed.initialize``.
    """
    logging.set_level(L.INFO)

    logger.info(f"Launch config file: {config_file}")
    configs = load_config(config_file)
    if option == "test":
        logger.info("Modify configs for testing")
        configs = modify_config_for_test(configs, tmpdir)
    elif option == "profile":
        logger.info("Modify configs for profiling")
        # TODO: output_dir is captured *before* modify_config_for_profile is
        # applied; confirm whether the pre- or post-modification value is
        # the intended destination for the profiler dumps.
        output_dir = configs["output_dir"]
        configs = modify_config_for_profile(configs, tmpdir)

    distributed.initialize(tmpdir, rank, n_process)

    rank = distributed.rank()
    seed = random.randint(0, 2**31) + rank
    logger.info(f"Fix seed={seed}")
    rng = np.random.RandomState(seed)
    # Draw one sub-seed per library from ``rng`` (torch, numpy, random, in
    # that order). ``dtype=np.int64`` keeps the exclusive bound 2**32 - 1
    # representable where NumPy's default integer is 32-bit (e.g. Windows
    # with NumPy < 2), which would otherwise raise ValueError. ``int(...)``
    # converts the NumPy scalar because ``random.seed`` rejects non-builtin
    # integer types on Python 3.11+.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(int(rng.randint(0, 2**32 - 1, dtype=np.int64)))

    logger.info("Run main")
    if option == "profile":
        cprofile = cProfile.Profile()
        cprofile.enable()
        try:
            with profile() as torch_prof:
                # Evaluated only for its side effects; the value is discarded.
                parse_config(configs)["/main"]
        finally:
            # Stop collecting even if /main raises.
            cprofile.disable()
        torch.save(torch_prof,
                   os.path.join(output_dir, f"torch_profiler-{rank}.pt"))
        # NOTE: this file holds cProfile (pstats) data despite the ".pt" suffix.
        cprofile.dump_stats(os.path.join(output_dir, f"cprofile-{rank}.pt"))
    else:
        # Evaluated only for its side effects; the value is discarded.
        parse_config(configs)["/main"]
示例#3
0
def main() -> None:
    """Iterate the validation dataset, in place of the training dataset, on CPU.

    Reads ``--config_file``/``-c``, redirects ``base_output_dir`` and
    ``output_dir`` into a temporary directory, forces the device to CPU,
    then walks every batch of ``/train_dataset`` through a DataLoader
    (batch size 1, using ``/collate_fn``) and prints any batch that comes
    back as ``None``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config_file", "-c", required=True, type=str)
    config_path = arg_parser.parse_args().config_file
    raw = load_config(config_path)
    with tempfile.TemporaryDirectory() as scratch:
        # Keep all outputs inside the throwaway directory and run on CPU.
        raw["base_output_dir"] = scratch
        raw["output_dir"] = os.path.join(scratch, "output")
        raw["device"]["type_str"] = "cpu"
        # Feed the validation split through the training-dataset slot.
        raw["train_dataset"] = raw["valid_dataset"]
        parsed = parse_config(raw, custom_types=types)
        dataset = parsed["/train_dataset"]
        collate = parsed["/collate_fn"]
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             collate_fn=collate)
        none_batches = (batch for batch in tqdm(loader) if batch is None)
        for batch in none_batches:
            print(batch)
示例#4
0
 def test_custom_types(self):
     """A type registered via ``custom_types`` is used to build ``/main``."""
     def make_foo():
         return 10

     parsed = parse_config({"main": {"type": "foo"}},
                           custom_types={"foo": make_foo})
     assert parsed["/main"] == 10