Beispiel #1
0
 def test_overwrite(self):
     """A file's own assignment overrides the value from a file it imports.

     ``x.py`` lists ``sub.py`` under ``imports`` and both files assign ``x``;
     the expected result holds only ``x``, with the value from ``x.py``.
     """
     with tempfile.TemporaryDirectory() as workdir:
         base_path = os.path.join(workdir, "sub.py")
         main_path = os.path.join(workdir, "x.py")
         with open(base_path, "w") as fp:
             fp.write("x = 0")
         with open(main_path, "w") as fp:
             # Two literal pieces joined at compile time: an `imports` line,
             # then the overriding assignment.
             fp.write(
                 "imports = [\"sub.py\"]\n"
                 "x = 10"
             )
         assert {"x": 10} == load_config(main_path)
Beispiel #2
0
def main() -> None:
    """Iterate a config's validation dataset on CPU, one sample per batch.

    Reads ``--config_file``/``-c`` from the command line and loads it. It then
    redirects every output directory into a throw-away temporary directory,
    forces the CPU device and substitutes ``valid_dataset`` for
    ``train_dataset``. The parsed ``/train_dataset`` is walked through a
    ``DataLoader`` using the config's ``/collate_fn``. Any batch that comes
    back as ``None`` is printed — presumably to surface samples the pipeline
    failed to collate (TODO confirm intent).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config_file", "-c", required=True, type=str)
    config_path = cli.parse_args().config_file
    raw = load_config(config_path)
    with tempfile.TemporaryDirectory() as scratch:
        # Keep every artefact inside the temporary directory.
        raw["base_output_dir"] = scratch
        raw["output_dir"] = os.path.join(scratch, "output")
        raw["device"]["type_str"] = "cpu"
        # Route the validation set through the training-dataset slot so the
        # "/train_dataset" lookup below iterates it.
        raw["train_dataset"] = raw["valid_dataset"]
        parsed = parse_config(raw, custom_types=types)
        batches = torch.utils.data.DataLoader(
            parsed["/train_dataset"],
            batch_size=1,
            collate_fn=parsed["/collate_fn"],
        )
        for batch in tqdm(batches):
            if batch is None:
                print(batch)
Beispiel #3
0
def _seed_all_rngs(rank: int) -> None:
    """Seed the torch, NumPy and Python global RNGs for this process.

    A base seed is drawn from the global ``random`` module, ``rank`` is
    added to it, and the result is logged. A private ``RandomState`` seeded
    with it then derives one seed each for torch, NumPy and ``random``, in
    that order.
    """
    seed = random.randint(0, 2**31) + rank
    logger.info(f"Fix seed={seed}")
    rng = np.random.RandomState(seed)

    def next_seed() -> int:
        # Request int64 explicitly: the default ``int`` dtype is 32-bit on
        # Windows with NumPy < 2, where an upper bound of 2**32 - 1 raises
        # ValueError. Convert back to a Python int because ``random.seed``
        # rejects NumPy integer scalars on Python 3.11+.
        return int(rng.randint(0, 2**32 - 1, dtype=np.int64))

    torch.manual_seed(next_seed())
    np.random.seed(next_seed())
    random.seed(next_seed())


def launch(config_file: str, option: Optional[str], tmpdir: str,
           rank: Optional[int], n_process: Optional[int]):
    """Load ``config_file``, initialize distributed execution and run ``/main``.

    Args:
        config_file: Path of the config file passed to ``load_config``.
        option: ``"test"`` rewrites the configs with
            ``modify_config_for_test``. ``"profile"`` rewrites them with
            ``modify_config_for_profile`` and runs ``/main`` under both
            cProfile and ``profile()``. It saves the results as
            ``torch_profiler-{rank}.pt`` and ``cprofile-{rank}.pt`` in the
            config's original ``output_dir``. Any other value, including
            ``None``, runs the configs unchanged.
        tmpdir: Directory handed to the config modifiers and to
            ``distributed.initialize``.
        rank: Passed to ``distributed.initialize``; the effective rank is
            then re-read from ``distributed.rank()``.
        n_process: Passed to ``distributed.initialize``.
    """
    logging.set_level(L.INFO)

    logger.info(f"Launch config file: {config_file}")
    configs = load_config(config_file)
    if option == "test":
        logger.info("Modify configs for testing")
        configs = modify_config_for_test(configs, tmpdir)
    elif option == "profile":
        logger.info("Modify configs for profiling")
        # Captured before modify_config_for_profile, which may rewrite
        # "output_dir" (presumably into tmpdir — TODO confirm). The profiler
        # dumps below are written to this original directory.
        output_dir = configs["output_dir"]
        configs = modify_config_for_profile(configs, tmpdir)

    distributed.initialize(tmpdir, rank, n_process)

    rank = distributed.rank()
    _seed_all_rngs(rank)

    logger.info("Run main")
    if option == "profile":
        cprofile = cProfile.Profile()
        cprofile.enable()
        try:
            with profile() as torch_prof:
                # The looked-up value is discarded; presumably resolving
                # "/main" is what executes the program (TODO confirm).
                parse_config(configs)["/main"]
        finally:
            # Stop cProfile even if main raises, so it does not keep
            # tracing whatever runs afterwards in this process.
            cprofile.disable()
        torch.save(torch_prof,
                   os.path.join(output_dir, f"torch_profiler-{rank}.pt"))
        cprofile.dump_stats(os.path.join(output_dir, f"cprofile-{rank}.pt"))
    else:
        parse_config(configs)["/main"]
Beispiel #4
0
 def test_simple_case(self):
     """A file with one top-level assignment loads as a one-entry dict."""
     with tempfile.TemporaryDirectory() as workdir:
         config_path = os.path.join(workdir, "x.py")
         with open(config_path, "w") as fp:
             fp.write("x = 10")
         expected = {"x": 10}
         assert expected == load_config(config_path)