Beispiel #1
0
    def test_dump_config(self, tmp_path: Path):
        """Dumping a Config writes ``config.yml`` including unspecified defaults.

        Loads ``tests/testconfigs/test_config.yml``, dumps it into a temporary
        ``runs`` directory, then checks that ``config.yml`` was created and
        that the reloaded YAML contains every key in
        ``expected_keys_with_defaults``.
        """
        run_dir = tmp_path / "runs"
        run_dir.mkdir(exist_ok=True, parents=True)
        path = Path("tests/testconfigs/test_config.yml")

        cfg = Config(cfg_path=path)
        cfg.run_dir = run_dir

        # check that the file gets created
        cfg.dump_config(run_dir)

        cfg_path = run_dir / "config.yml"
        assert cfg_path.is_file(), f"expected {cfg_path} to be written"

        # typ="safe" loads plain YAML types only, no arbitrary Python objects
        yaml = YAML(typ="safe")
        with cfg_path.open("r") as fp:
            cfg2 = yaml.load(fp)

        # check that defaults not specified in the input file are written out
        expected_keys_with_defaults = [
            "autoregressive",
            "pixel_dims",
            "num_workers",
            "seed",
            "device",
            "learning_rate",
            "time_str",
            "run_dir",
        ]
        missing = [key for key in expected_keys_with_defaults if key not in cfg2]
        assert not missing, f"keys missing from dumped config: {missing}"

    def test_pollution(self, tmp_path):
        """A Trainer built on the Beijing pollution data has a populated train set.

        Also checks that every variable named in ``pollution.yml``
        (``input_variables`` plus ``target_variable``) exists in the dataset.
        """
        ds = get_pollution_data_beijing().to_xarray()
        cfg = Config(cfg_path=Path("tests/testconfigs/pollution.yml"))
        cfg.run_dir = tmp_path
        trainer = Trainer(cfg, ds)

        # Make the "configured variables exist" check explicit instead of
        # relying on a KeyError from an otherwise-unused selection.
        input_variables = [] if cfg.input_variables is None else cfg.input_variables
        missing = [
            var
            for var in input_variables + [cfg.target_variable]
            if var not in ds.variables
        ]
        assert not missing, f"configured variables not in dataset: {missing}"

        train_data = trainer.train_dl.dataset
        assert train_data.lookup_table != {}
        assert train_data.y != {}
        assert train_data.x_d != {}

    def test_train_test_split(self, tmp_path):
        """Each subset from ``train_test_split`` lies within its configured period.

        For the "train", "test" and "validation" subsets, checks that the
        returned data is non-empty and that its time coordinate falls between
        the matching ``cfg.<subset>_start_date`` and ``cfg.<subset>_end_date``.
        """
        ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
        cfg = Config(Path("tests/testconfigs/test_config.yml"))
        cfg.run_dir = tmp_path

        periods = {
            "train": (cfg.train_start_date, cfg.train_end_date),
            "test": (cfg.test_start_date, cfg.test_end_date),
            "validation": (cfg.validation_start_date, cfg.validation_end_date),
        }
        for subset, (start, end) in periods.items():
            split_ds = train_test_split(ds, cfg, subset=subset)
            times = pd.to_datetime(split_ds.time.values)

            assert times.size > 0, f"{subset} split is empty"
            # Normalise config dates (str / date / Timestamp) before comparing.
            assert times.min() >= pd.to_datetime(start), (subset, times.min(), start)
            assert times.max() <= pd.to_datetime(end), (subset, times.max(), end)

    def test_tester(self, tmp_path):
        """Train briefly, run the Tester, and check the saved NetCDF output.

        Checks that the newest ``*.nc`` file in ``cfg.run_dir`` records the
        configured ``horizon``. Also checks, at day resolution, that its time
        axis runs from ``test_start_date + (seq_length + horizon)`` months to
        ``test_end_date - 1`` month.
        """
        ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
        cfg = Config(Path("tests/testconfigs/test_config.yml"))
        cfg._cfg["n_epochs"] = 1
        cfg._cfg["num_workers"] = 1
        cfg._cfg["horizon"] = 5
        cfg.run_dir = tmp_path

        # initialise the train directory before constructing the Tester
        trainer = Trainer(cfg, ds)
        trainer.train_and_validate()

        tester = Tester(cfg=cfg, ds=ds)

        # TODO: test the tester evaluation loop
        tester.run_test()
        # TODO: test that plots created, outputs saved
        outfiles = sorted(cfg.run_dir.glob("*.nc"))
        assert outfiles, f"run_test wrote no .nc file to {cfg.run_dir}"
        outfile = outfiles[-1]

        # Read what we need, then close the file so the handle is not leaked
        # (an open handle can also block tmp_path cleanup on some platforms).
        with xr.open_dataset(outfile) as out_ds:
            out_horizon = int(out_ds.horizon.values)
            min_time = pd.to_datetime(out_ds.time.values.min()).round("D")
            max_time = pd.to_datetime(out_ds.time.values.max()).round("D")

        assert out_horizon == cfg.horizon

        # Check that the times are correct: compare calendar dates only
        # (year, month, day), ignoring any time-of-day component.
        exp_min_time = cfg.test_start_date + DateOffset(
            months=(cfg.seq_length + cfg.horizon))
        assert min_time.date() == exp_min_time.date(), (min_time, exp_min_time)

        exp_max_time = cfg.test_end_date - DateOffset(months=1)
        assert max_time.date() == exp_max_time.date(), (max_time, exp_max_time)
Beispiel #5
0
def create_and_assign_temp_run_path_to_config(cfg: Config,
                                              tmp_path: Path) -> None:
    """Set ``cfg.run_dir`` to ``<tmp_path>/runs``, creating it if necessary.

    Missing parent directories are created too; an existing ``runs``
    directory is reused rather than treated as an error.
    """
    runs_dir = tmp_path / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    cfg.run_dir = runs_dir

 def test_trainer(self, tmp_path: Path):
     """A Trainer can be built from the test config and a small linear dataset.

     Beyond not raising, construction should leave a populated training
     dataset (non-empty lookup table) on ``trainer.train_dl``.
     """
     ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
     cfg = Config(Path("tests/testconfigs/test_config.yml"))
     cfg.run_dir = tmp_path
     trainer = Trainer(cfg=cfg, ds=ds)

     assert trainer.train_dl.dataset.lookup_table != {}