def test_dump_config(self, tmp_path: Path):
    """Config.dump_config writes a ``config.yml`` that includes default keys.

    Verifies (1) the file is created in ``run_dir`` and (2) defaults that
    were not explicitly specified in the source config are serialized too.
    """
    run_dir = tmp_path / "runs"
    run_dir.mkdir(exist_ok=True, parents=True)
    path = Path("tests/testconfigs/test_config.yml")
    cfg = Config(cfg_path=path)
    cfg.run_dir = run_dir

    # check that defaults not specified are written to file
    # check that the file gets created
    cfg.dump_config(run_dir)
    assert "config.yml" in [f.name for f in run_dir.glob("*")]

    cfg_path = run_dir / "config.yml"
    with cfg_path.open("r") as fp:
        yaml = YAML(typ="safe")
        cfg2 = yaml.load(fp)

    expected_keys_with_defaults = [
        "autoregressive",
        "pixel_dims",
        "num_workers",
        "seed",
        "device",
        "learning_rate",
        "time_str",
        "run_dir",
    ]
    for key in expected_keys_with_defaults:
        # membership test directly on the loaded mapping — the original
        # `key in [l for l in cfg2.keys()]` built a throwaway list per key
        assert key in cfg2
def test_pollution(self, tmp_path):
    """Trainer built from the Beijing pollution dataset populates its
    internal lookup structures (lookup_table, y, x_d) on construction.
    """
    ds = get_pollution_data_beijing().to_xarray()
    cfg = Config(cfg_path=Path("tests/testconfigs/pollution.yml"))
    cfg.run_dir = tmp_path

    trainer = Trainer(cfg, ds)

    # NOTE(review): the original also computed a `train_ds` slice of the
    # input/target variables that was never used — dead code, removed.
    assert trainer.train_dl.dataset.lookup_table != {}
    assert trainer.train_dl.dataset.y != {}
    assert trainer.train_dl.dataset.x_d != {}
def test_train_test_split(self, tmp_path):
    """train_test_split produces all three subsets and the config exposes
    every split-boundary date.

    The original test computed the splits and touched the date attributes
    without asserting anything; the checks below make the intent explicit.
    """
    ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
    cfg = Config(Path("tests/testconfigs/test_config.yml"))
    cfg.run_dir = tmp_path

    # each subset should be constructible without error and return data
    for subset in ("train", "test", "validation"):
        split = train_test_split(ds, cfg, subset=subset)
        assert split is not None, f"train_test_split returned None for {subset!r}"

    # accessing each attribute raises AttributeError (failing the test)
    # if the config does not define it — same contract as the original
    for attr in (
        "train_start_date",
        "train_end_date",
        "validation_start_date",
        "validation_end_date",
        "test_start_date",
        "test_end_date",
    ):
        getattr(cfg, attr)
def test_tester(self, tmp_path):
    """Train for one epoch, run the Tester, and sanity-check the saved .nc
    output: the stored horizon and the first/last timestamps match the
    config's test window.
    """
    ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
    cfg = Config(Path("tests/testconfigs/test_config.yml"))
    cfg._cfg["n_epochs"] = 1
    cfg._cfg["num_workers"] = 1
    cfg._cfg["horizon"] = 5
    cfg.run_dir = tmp_path

    # initialise the train directory!
    trainer = Trainer(cfg, ds)
    trainer.train_and_validate()

    tester = Tester(cfg=cfg, ds=ds)
    # TODO: test the tester evaluation loop
    tester.run_test()

    # TODO: test that plots created, outputs saved
    outfile = sorted(cfg.run_dir.glob("*.nc"))[-1]
    out_ds = xr.open_dataset(outfile)
    assert int(out_ds.horizon.values) == cfg.horizon

    def _same_calendar_day(actual, expected) -> bool:
        # compare year/month/day only, ignoring time-of-day
        return (actual.year, actual.month, actual.day) == (
            expected.year,
            expected.month,
            expected.day,
        )

    # Check that the times are correct
    min_time = pd.to_datetime(out_ds.time.values.min()).round("D")
    exp_min_time = cfg.test_start_date + DateOffset(
        months=(cfg.seq_length + cfg.horizon)
    )
    assert _same_calendar_day(min_time, exp_min_time)

    max_time = pd.to_datetime(out_ds.time.values.max()).round("D")
    exp_max_time = cfg.test_end_date - DateOffset(months=1)
    assert _same_calendar_day(max_time, exp_max_time)
def create_and_assign_temp_run_path_to_config(cfg: Config, tmp_path: Path) -> None:
    """Create a ``runs`` directory under *tmp_path* and assign it to
    ``cfg.run_dir``. Idempotent: the directory may already exist.
    """
    runs_dir = tmp_path / "runs"
    runs_dir.mkdir(exist_ok=True, parents=True)
    cfg.run_dir = runs_dir
def test_trainer(self, tmp_path: Path):
    """Smoke test: a Trainer can be constructed from a small linear dataset
    without raising.
    """
    small_ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
    config = Config(Path("tests/testconfigs/test_config.yml"))
    config.run_dir = tmp_path
    # construction succeeding is the whole assertion
    Trainer(cfg=config, ds=small_ds)