def test_with_file_cache(self):
    """A ``with_file_cache`` node keeps returning its first result for a path.

    The same cache path is parsed twice. The first time the wrapped ``select``
    config is expected to produce 0. The second time its options would select
    1, yet the result must still be 0, i.e. the earlier value is reused.
    """

    def build(cache_path, selected_value):
        # Same shape both times; only the option value behind "key" differs.
        return {
            "main": {
                "type": "with_file_cache",
                "path": cache_path,
                "config": {
                    "type": "select",
                    "key": "key",
                    "options": {"key": selected_value},
                },
            }
        }

    with tempfile.TemporaryDirectory() as workdir:
        cache_path = os.path.join(workdir, "cache")

        first = parse_config(build(cache_path, 0))
        assert first["/main"] == 0

        second = parse_config(build(cache_path, 1))
        assert second["/main"] == 0
def launch(
    config_file: str,
    option: Optional[str],
    tmpdir: str,
    rank: Optional[int],
    n_process: Optional[int],
):
    """Load a config file and evaluate its ``/main`` entry.

    Args:
        config_file: Path handed to ``load_config``.
        option: ``"test"`` rewrites the configs with ``modify_config_for_test``.
            ``"profile"`` rewrites them with ``modify_config_for_profile`` and
            evaluates ``/main`` under both cProfile and ``profile()``, writing
            both results into the original ``output_dir``. Any other value,
            including ``None``, evaluates ``/main`` on the unmodified configs.
        tmpdir: Directory passed to the config modifiers and to
            ``distributed.initialize``.
        rank: Passed to ``distributed.initialize``. The rank used afterwards is
            re-read from ``distributed.rank()``.
        n_process: Passed to ``distributed.initialize``.
    """
    # NOTE(review): `logging` here exposes `set_level`, so it is presumably a
    # project module rather than the stdlib one; `L` is defined elsewhere.
    logging.set_level(L.INFO)
    logger.info(f"Launch config file: {config_file}")
    configs = load_config(config_file)

    profiling = option == "profile"
    if option == "test":
        logger.info("Modify configs for testing")
        configs = modify_config_for_test(configs, tmpdir)
    elif profiling:
        logger.info("Modify configs for profiling")
        # Read before the profile modifier runs; profiler dumps go here.
        output_dir = configs["output_dir"]  # TODO
        configs = modify_config_for_profile(configs, tmpdir)

    distributed.initialize(tmpdir, rank, n_process)
    rank = distributed.rank()

    # Random base seed, offset by rank so processes differ.
    seed = random.randint(0, 2**31) + rank
    logger.info(f"Fix seed={seed}")
    seed_source = np.random.RandomState(seed)
    # Draw order is significant: torch, then numpy, then the stdlib RNG.
    for apply_seed in (torch.manual_seed, np.random.seed, random.seed):
        apply_seed(seed_source.randint(0, 2**32 - 1))

    logger.info("Run main")
    if not profiling:
        # The value is discarded; the lookup is done only for its side effects.
        parse_config(configs)["/main"]
        return

    cprofile = cProfile.Profile()
    cprofile.enable()
    with profile() as torch_prof:
        parse_config(configs)["/main"]
    cprofile.disable()
    torch.save(torch_prof, os.path.join(output_dir, f"torch_profiler-{rank}.pt"))
    cprofile.dump_stats(os.path.join(output_dir, f"cprofile-{rank}.pt"))
def main() -> None:
    """Run the configured validation dataset through a DataLoader on CPU.

    Reads ``--config_file``/``-c`` from the command line and redirects both
    output directories into a temporary directory. It then forces the device
    to CPU and substitutes ``valid_dataset`` for ``train_dataset``. Finally it
    iterates the resulting ``/train_dataset`` one sample per batch, printing
    every batch that comes back as ``None``. This is presumably a smoke check
    of the data pipeline; confirm with the script's owner.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config_file", "-c", required=True, type=str)
    cli_args = cli.parse_args()

    raw_configs = load_config(cli_args.config_file)
    with tempfile.TemporaryDirectory() as scratch_dir:
        raw_configs["base_output_dir"] = scratch_dir
        raw_configs["output_dir"] = os.path.join(scratch_dir, "output")
        raw_configs["device"]["type_str"] = "cpu"
        # Exercise the validation data through the training-dataset slot.
        raw_configs["train_dataset"] = raw_configs["valid_dataset"]

        parsed = parse_config(raw_configs, custom_types=types)
        loader = torch.utils.data.DataLoader(
            parsed["/train_dataset"],
            batch_size=1,
            collate_fn=parsed["/collate_fn"],
        )
        for batch in tqdm(loader):
            if batch is None:
                print(batch)
def test_custom_types(self):
    """A ``type`` name registered via ``custom_types`` resolves to its factory."""
    factories = {"foo": lambda: 10}
    parsed = parse_config({"main": {"type": "foo"}}, custom_types=factories)
    assert parsed["/main"] == 10