コード例 #1
0
    def test_train_test_split(self, tmp_path):
        """Check that ``train_test_split`` restricts each subset to its configured window.

        Each of the "train", "validation" and "test" subsets must be non-empty,
        and every timestamp in it must lie inside the matching
        ``<subset>_start_date`` / ``<subset>_end_date`` pair from the test config.

        Args:
            tmp_path (pathlib.Path): pytest temporary directory used as the run directory.
        """
        import pandas as pd

        ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
        cfg = Config(Path("tests/testconfigs/test_config.yml"))
        cfg.run_dir = tmp_path

        train = train_test_split(ds, cfg, subset="train")
        test = train_test_split(ds, cfg, subset="test")
        valid = train_test_split(ds, cfg, subset="validation")

        # Previously the config dates were only evaluated as bare expressions, so this
        # test could never fail. Compare every split against its configured window.
        windows = {
            "train": (train, cfg.train_start_date, cfg.train_end_date),
            "validation": (valid, cfg.validation_start_date, cfg.validation_end_date),
            "test": (test, cfg.test_start_date, cfg.test_end_date),
        }
        for subset, (split, start, end) in windows.items():
            times = pd.to_datetime(split.time.values)
            assert len(times) > 0, f"The {subset} split is empty"
            assert times.min() >= pd.to_datetime(start), (
                f"The {subset} split starts at {times.min()}, before {start}"
            )
            assert times.max() <= pd.to_datetime(end), (
                f"The {subset} split ends at {times.max()}, after {end}"
            )
コード例 #2
0
    def test_tester(self, tmp_path):
        """Train for one epoch, run the Tester and check the saved predictions file.

        The newest ``*.nc`` file written to ``cfg.run_dir`` must carry the
        configured ``horizon``. Its time axis must start ``seq_length + horizon``
        months after ``test_start_date`` and end one month before ``test_end_date``.

        Args:
            tmp_path (pathlib.Path): pytest temporary directory used as the run directory.
        """
        ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
        cfg = Config(Path("tests/testconfigs/test_config.yml"))
        cfg._cfg["n_epochs"] = 1
        cfg._cfg["num_workers"] = 1
        cfg._cfg["horizon"] = 5
        cfg.run_dir = tmp_path

        # initialise the train directory!
        trainer = Trainer(cfg, ds)
        trainer.train_and_validate()

        tester = Tester(cfg=cfg, ds=ds)

        #  TODO: test the tester evaluation loop
        tester.run_test()
        #  TODO: test that plots created, outputs saved
        outfiles = sorted(cfg.run_dir.glob("*.nc"))
        assert outfiles, f"The Tester wrote no *.nc output to {cfg.run_dir}"
        outfile = outfiles[-1]

        # Use a context manager so the netCDF file handle is released even when an
        # assertion below fails.
        with xr.open_dataset(outfile) as out_ds:
            assert int(out_ds.horizon.values) == cfg.horizon

            # Times are compared at day resolution, so round away any sub-day offset.
            min_time = pd.to_datetime(out_ds.time.values.min()).round("D")
            max_time = pd.to_datetime(out_ds.time.values.max()).round("D")

        #  Check that the times are correct
        exp_min_time = cfg.test_start_date + DateOffset(
            months=(cfg.seq_length + cfg.horizon))
        exp_max_time = cfg.test_end_date - DateOffset(months=1)

        # Compare calendar dates (year, month, day); show both values on failure.
        assert min_time.date() == exp_min_time.date(), (min_time, exp_min_time)
        assert max_time.date() == exp_max_time.date(), (max_time, exp_max_time)
コード例 #3
0
    def test_linear_example(self):
        """End-to-end train/test run on the linear dataset with two static inputs.

        Builds the linear dataset with ``epsilon_sigma=10`` plus example static
        data, trains (checking the data the trainer loaded for the training
        window), saves the loss curves, then runs the tester and saves the
        resulting timeseries twice.
        """
        # NOTE(review): unlike the other tests, run_dir is not pointed at a pytest tmp
        # directory, so outputs go wherever the test config's run_dir points.
        linear_ds = create_linear_ds(epsilon_sigma=10)
        static_inputs = create_static_example_data(linear_ds)

        config = Config(Path("tests/testconfigs/test_config.yml"))
        config._cfg["static_inputs"] = ["static_const", "static_rand"]

        # ---- training ----
        trainer = Trainer(config, linear_ds, static_data=static_inputs)
        training_window = slice(config.train_start_date, config.train_end_date)
        self.check_loaded_data(
            config, trainer, data=linear_ds.sel(time=training_window)
        )
        loss_history = trainer.train_and_validate()
        save_loss_curves(loss_history, config)

        # ---- testing ----
        predictions = Tester(config, linear_ds, static_data=static_inputs).run_test()
        # Saved twice, presumably to check that overwriting existing output works — TODO confirm.
        save_timeseries(predictions, config)
        save_timeseries(predictions, config)
コード例 #4
0
    def test_linear_example(self, tmp_path):
        """Check that the PixelDataLoader returns the expected linear-dataset samples.

        Builds a noise-free linear dataset (``alpha=0``, ``beta=2``,
        ``epsilon_sigma=0``) and loads every training sample from a
        ``PixelDataLoader``. For one sample it compares the stored and returned
        ``x``/``y`` values with values recomputed from the normalised raw data.
        It also checks that the un-normalised target is the expected linear
        combination of the inputs.

        Args:
            tmp_path (pathlib.Path): pytest temporary directory assigned as the
                config's run directory.
        """
        cfg = Config(Path("tests/testconfigs/test_config.yml"))
        create_and_assign_temp_run_path_to_config(cfg, tmp_path)

        #  Create linear dataset
        alpha = 0
        beta = 2
        epsilon_sigma = 0

        ds = create_linear_ds(
            horizon=cfg.horizon, alpha=alpha, beta=beta, epsilon_sigma=epsilon_sigma
        ).isel(lat=slice(0, 2), lon=slice(0, 2))
        static = create_static(cfg=cfg, ds=ds)
        dl = PixelDataLoader(
            ds,
            cfg=cfg,
            num_workers=1,
            mode="train",
            batch_size=cfg.batch_size,
            DEBUG=True,
            static_data=static,
        )

        #  load all of the data into memory
        data = load_all_data_from_dl_into_memory(dl)
        x = data["x_d"]

        # Two extra input features are expected when encode_doys is set.
        n_features = (
            len(cfg.input_variables) + 2
            if cfg.encode_doys
            else len(cfg.input_variables)
        )
        # NOTE(review): the original comment described x as
        # (n_samples, n_features, seq_length), but this asserts a 2-D shape and
        # later code indexes x[SAMPLE]. Confirm which layout
        # load_all_data_from_dl_into_memory returns.
        assert x.shape == (n_features, cfg.seq_length)
        assert x.shape[-1] == cfg.seq_length
        y = data["y"]
        times = pd.to_datetime(data["time"].astype("datetime64[ns]").flatten())

        # matching batch dims (n_samples) for all samples
        assert x.shape[0] == y.shape[0]

        #  test ONE SINGLE (x, y) sample
        SAMPLE = 1

        # get metadata for sample
        idx = int(data["index"][SAMPLE])
        pixel, valid_current_time_index = dl.dataset.lookup_table[idx]
        target_time = times[SAMPLE]

        # Reconstruct the input times. The odd +2 days offset and the "nearest"
        # selections below work around the imperfect float -> datetime64[ns]
        # translation of the stored times.
        max_time = target_time - DateOffset(months=cfg.horizon) + DateOffset(days=2)
        min_time = max_time - DateOffset(months=cfg.seq_length)
        input_times = pd.date_range(min_time, max_time, freq="M")[-cfg.seq_length :]

        #  recreate the data that should be loaded from the raw xr.Dataset
        stacked, _ = _stack_xarray(ds, spatial_coords=cfg.pixel_dims)
        normalizer = dl.normalizer
        norm_stacked = normalizer.transform(stacked)

        all_y = norm_stacked["target"].sel(sample=pixel)
        _y = all_y.sel(time=target_time, method="nearest")
        all_x = norm_stacked["feature"].sel(sample=pixel)
        _x_d = all_x.sel(time=input_times, method="nearest")

        #  check that the dataloader saves & returns the correct values
        assert np.allclose(
            dl.dataset.y[pixel], all_y.values
        ), "The DataLoader saves incorrect y values to memory"
        # allclose (not isclose) so the assert is unambiguous for multi-element arrays.
        assert np.allclose(
            _y.values, y[SAMPLE]
        ), "The DataLoader returns an incorrect value from the Dataset"

        #  input (X) data
        dataset_loaded = dl.dataset.x_d[pixel]

        # Compare only the non-NaN entries, because NaN never compares equal.
        expected = all_x.values.reshape(dataset_loaded.shape)
        mask = np.isnan(expected)
        expected = expected[~mask]
        dataset_loaded = dataset_loaded[~mask]

        assert np.allclose(
            dataset_loaded, expected
        ), f"The dataloader is saving the wrong data to the lookup table. {dataset_loaded[:10]} {expected[:10]}"

        #  get input X data from INDEX (not times)
        max_input_ix = int(valid_current_time_index)
        min_input_ix = int(max_input_ix - cfg.seq_length) + 1
        _x_d_index_values = all_x.values[min_input_ix : max_input_ix + 1]

        assert np.allclose(_x_d_index_values, _x_d.values)

        # _x_d_index_values is already a numpy array (sliced from `.values`), so the
        # former `_x_d_index_values.values` raised AttributeError before comparing.
        assert np.allclose(
            _x_d_index_values, x[SAMPLE]
        ), "The dynamic data is not the data we expect"

        #  check that the raw data is the linear combination we expect
        # "target" should be linear combination of previous timestep "feature":
        # betas is a column of zeros with `beta` in the last row, so
        # x @ betas == beta * (last input timestep).
        zeros = np.zeros((cfg.seq_length - 1, 1))
        betas = np.append(zeros, beta).reshape(-1, 1)
        unnorm_x = dl.dataset.normalizer.individual_inverse(
            x[SAMPLE], pixel_id=pixel, variable=cfg.input_variables[0]
        )
        unnorm_y = dl.dataset.normalizer.individual_inverse(
            y[SAMPLE], pixel_id=pixel, variable=cfg.target_variable
        )

        assert np.allclose(unnorm_x @ betas, unnorm_y)
コード例 #5
0
 def test_trainer(self, tmp_path: Path):
     """Smoke test: building a Trainer on a small linear dataset must not raise.

     Args:
         tmp_path: pytest temporary directory, used as the config's run directory.
     """
     small_ds = create_linear_ds().isel(lat=slice(0, 5), lon=slice(0, 5))
     config_path = Path("tests/testconfigs/test_config.yml")
     config = Config(config_path)
     config.run_dir = tmp_path
     # The instance is discarded; only construction is exercised here.
     Trainer(cfg=config, ds=small_ds)
コード例 #6
0
from pathlib import Path
from spatio_temporal.config import Config
from spatio_temporal.training.trainer import Trainer
from spatio_temporal.training.tester import Tester
from tests.utils import create_linear_ds
from spatio_temporal.training.eval_utils import save_loss_curves, save_timeseries
from tqdm import tqdm

if __name__ == "__main__":
    #  EXPLICITLY write out training loop (good for debugging)
    ds = create_linear_ds(epsilon_sigma=10)
    cfg = Config(Path("tests/testconfigs/test_config.yml"))

    #  Train
    trainer = Trainer(cfg, ds)
    tester = Tester(cfg, ds)
    normalizer = trainer.train_dl.dataset.normalizer

    cfg._cfg["n_epochs"] = 2
    trainer.train_and_validate()
    preds = tester.run_test()

    assert False

    ## Test one loop
    #  Items for training loop
    model = trainer.model
    optimizer = trainer.optimizer
    loss_fn = trainer.loss_fn
    dl = trainer.train_dl