def test__batches_from_mapper_invalid_times(mapper):
    """Expect ``batches_from_mapper`` to raise ``ValueError`` for bad timesteps.

    The requested timesteps are the first two keys of ``mapper`` followed by
    two hard-coded timestamps (presumably absent from the mapper fixture).
    """
    known_times = list(mapper.keys())[:2]
    extra_times = ["20000101.000000", "20000102.000000"]
    requested_times = [*known_times, *extra_times]

    with pytest.raises(ValueError):
        batches_from_mapper(
            mapper,
            DATA_VARS,
            timesteps_per_batch=2,
            timesteps=requested_times,
            needs_grid=False,
        )
def test_diagnostic_batches_from_mapper(mapper):
    """Check batch count and per-batch contents with two timesteps per batch.

    The expected number of batches is ``len(mapper) / 2`` rounded up; every
    batch must have a ``z`` variable of length ``Z_DIM_SIZE`` and exactly the
    data variables listed in ``DATA_VARS``.
    """
    batches = batches_from_mapper(
        mapper,
        DATA_VARS,
        timesteps_per_batch=2,
        needs_grid=False,
    )

    # Full pairs plus one extra batch when a single timestep is left over.
    full_pairs, leftover = divmod(len(mapper), 2)
    assert len(batches) == full_pairs + leftover

    expected_vars = set(DATA_VARS)
    for batch in batches:
        assert len(batch["z"]) == Z_DIM_SIZE
        assert set(batch.data_vars) == expected_vars
def test_batches_from_mapper_timestep_list(
    mapper, total_times, times_per_batch, valid_num_batches
):
    """Check batching when an explicit list of timesteps is supplied.

    The first ``total_times`` keys of ``mapper`` are requested; the resulting
    sequence must contain ``valid_num_batches`` batches and must only draw on
    the requested timesteps.
    """
    requested = list(mapper.keys())[:total_times]
    sequence = batches_from_mapper(
        mapper,
        DATA_VARS,
        timesteps_per_batch=times_per_batch,
        timesteps=requested,
        needs_grid=False,
    )
    assert len(sequence) == valid_num_batches

    # Summing with an empty-tuple start concatenates the per-batch tuples held
    # in the sequence's ``_args`` into one flat tuple of timesteps.
    used = sum(sequence._args, ())
    assert set(used) <= set(requested)
def test_batches_from_mappper_different_indexing_conventions(tiles):
    """Check batching of a one-timestep dataset using the ``tiles`` coordinate.

    A zero-filled variable ``a`` with dims ``(time, tile, y, x)`` is wrapped in
    an ``XarrayMapper`` and batched together with ``lon`` at resolution
    ``c48``. Exactly one batch is expected, and its ``a`` must hold as many
    elements as the single time slice of the input.
    """
    n = 48
    zeros = np.zeros((1, 6, n, n))
    ds = xr.Dataset(
        data_vars={"a": (["time", "tile", "y", "x"], zeros)},
        coords={
            "time": [cftime.DatetimeJulian(2016, 8, 1)],
            "tile": tiles,
        },
    )
    mapper = loaders.mappers.XarrayMapper(ds)

    seq = batches_from_mapper(mapper, ["a", "lon"], res=f"c{n}")

    assert len(seq) == 1
    first_batch = seq[0]
    assert ds.a[0].size == first_batch.a.size
def test_batches_from_mapper(mapper):
    """Check that every batch matches the first batch's variables and shape.

    With two timesteps per batch the ``mapper`` fixture is expected to yield
    two batches. Each batch must have a ``z`` variable of length
    ``Z_DIM_SIZE``, exactly the data variables in ``DATA_VARS``, the same
    dimension names per variable as the first batch, and the same size along
    every dimension as the first batch.
    """
    batched_data_sequence = batches_from_mapper(
        mapper,
        DATA_VARS,
        timesteps_per_batch=2,
        needs_grid=False,
    )
    assert len(batched_data_sequence) == 2

    reference = batched_data_sequence[0]
    expected_dims = {name: reference[name].dims for name in reference}
    # Record the actual dimension sizes. The previous code used ``len(dim)``,
    # which is the length of the dimension *name* string, so the size
    # comparison below could never fail.
    expected_sizes = dict(reference.sizes)

    for batch in batched_data_sequence:
        assert len(batch["z"]) == Z_DIM_SIZE
        assert set(batch.data_vars) == set(DATA_VARS)
        for name in batch.data_vars:
            assert set(batch[name].dims) == set(expected_dims[name])
        # NOTE(review): assumes all batches share the first batch's sizes
        # (equal timesteps per batch) -- confirm against the mapper fixture.
        for dim, size in batch.sizes.items():
            assert size == expected_sizes[dim]