def test_batch_accessor_ds(sample_ds_3d):
    """Check that the Dataset ``.batch`` accessor matches ``BatchGenerator``.

    A generator built directly from the class and one built through
    ``sample_ds_3d.batch.generator`` (same ``input_dims``) must have the
    same length and yield pairwise-equal ``xr.Dataset`` batches.
    """
    bg_class = BatchGenerator(sample_ds_3d, input_dims={'x': 5})
    bg_acc = sample_ds_3d.batch.generator(input_dims={'x': 5})
    assert isinstance(bg_acc, BatchGenerator)
    # zip() stops at the shorter iterable, so a differing number of batches
    # would otherwise go unnoticed.
    assert len(bg_class) == len(bg_acc)
    for batch_class, batch_acc in zip(bg_class, bg_acc):
        assert isinstance(batch_acc, xr.Dataset)
        assert batch_class.equals(batch_acc)
def test_batch_accessor_da(sample_ds_3d):
    """Check that the DataArray ``.batch`` accessor matches ``BatchGenerator``.

    The ``'foo'`` variable of the fixture is batched once through the class
    and once through ``sample_da.batch.generator``; both generators must
    have the same length and yield pairwise-equal batches.
    """
    sample_da = sample_ds_3d['foo']
    bg_class = BatchGenerator(sample_da, input_dims={'x': 5})
    bg_acc = sample_da.batch.generator(input_dims={'x': 5})
    assert isinstance(bg_acc, BatchGenerator)
    # zip() stops at the shorter iterable, so a differing number of batches
    # would otherwise go unnoticed.
    assert len(bg_class) == len(bg_acc)
    for batch_class, batch_acc in zip(bg_class, bg_acc):
        assert batch_class.equals(batch_acc)
def test_batch_1d(sample_ds_1d, bsize):
    """Each batch must equal the matching ``bsize``-long slice along ``x``."""
    generator = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
    for index, batch in enumerate(generator):
        assert isinstance(batch, xr.Dataset)
        # TODO: maybe relax this exact-size requirement?
        assert batch.dims['x'] == bsize
        # Batch number ``index`` is expected to cover [start, start + bsize).
        start = bsize * index
        window = slice(start, start + bsize)
        reference = sample_ds_1d.isel(x=window)
        assert batch.equals(reference)
def test_preload_batch_true(sample_ds_1d):
    """With ``preload_batch=True`` the yielded batches report no chunks."""
    chunked_ds = sample_ds_1d.chunk({'x': 2})
    generator = BatchGenerator(
        chunked_ds,
        input_dims={'x': 2},
        preload_batch=True,
    )
    # The constructor argument must be exposed unchanged on the generator.
    assert generator.preload_batch is True
    for batch in generator:
        assert isinstance(batch, xr.Dataset)
        assert not batch.chunks
def test_batch_1d_concat(sample_ds_1d, bsize):
    """Check batch dimensions when ``concat_input_dims=True`` on 1D data.

    Every batch must have an ``x_input`` dimension of length ``bsize``, an
    ``input_batch`` dimension of length ``len(x) // bsize``, and must keep
    ``x`` among its coordinates.
    """
    generator = BatchGenerator(
        sample_ds_1d,
        input_dims={'x': bsize},
        concat_input_dims=True,
    )
    n_windows = sample_ds_1d.dims['x'] // bsize
    for batch in generator:
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x_input'] == bsize
        assert batch.dims['input_batch'] == n_windows
        assert 'x' in batch.coords
def test_batch_1d_no_coordinate(sample_ds_1d, bsize):
    """Batching must still work once the ``x`` coordinate is dropped.

    Regression test for issue #3.
    """
    without_coord = sample_ds_1d.drop_vars('x')
    generator = BatchGenerator(without_coord, input_dims={'x': bsize})
    for index, batch in enumerate(generator):
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x'] == bsize
        # Batch number ``index`` is expected to cover [start, start + bsize).
        start = bsize * index
        window = slice(start, start + bsize)
        reference = without_coord.isel(x=window)
        assert batch.equals(reference)
def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize):
    """Concatenated batches of coordinate-less data must not gain ``x``.

    Regression test for issue #3: with the ``x`` coordinate dropped and
    ``concat_input_dims=True``, batches keep the expected ``x_input`` and
    ``input_batch`` sizes and have no ``x`` coordinate.
    """
    without_coord = sample_ds_1d.drop_vars('x')
    generator = BatchGenerator(
        without_coord,
        input_dims={'x': bsize},
        concat_input_dims=True,
    )
    n_windows = sample_ds_1d.dims['x'] // bsize
    for batch in generator:
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x_input'] == bsize
        assert batch.dims['input_batch'] == n_windows
        assert 'x' not in batch.coords
def test_batch_3d_2d_input(sample_ds_3d, bsize):
    """Check batch sizes and batch count when batching over ``y`` and ``x``.

    Every batch must have ``x`` of length 20 and ``y`` of length ``bsize``,
    and the number of batches must equal the product of the number of
    whole windows along each of the two input dimensions.
    """
    # now iterate over both
    xbsize = 20
    bg = BatchGenerator(sample_ds_3d, input_dims={'y': bsize, 'x': xbsize})

    # Start from zero so an empty generator fails the count assertion below
    # instead of raising NameError on an unbound loop variable.
    n_batches = 0
    for n_batches, ds_batch in enumerate(bg, start=1):
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.dims['x'] == xbsize
        assert ds_batch.dims['y'] == bsize
        # TODO? Is it worth it to try to reproduce the internal logic of the
        # generator and verify that the slices are correct?

    expected_n_batches = (
        (sample_ds_3d.dims['x'] // xbsize)
        * (sample_ds_3d.dims['y'] // bsize)
    )
    assert n_batches == expected_n_batches
def test_batch_1d_overlap(sample_ds_1d, olap):
    """Overlapping batches must start every ``bsize - olap`` elements."""
    bsize = 10
    generator = BatchGenerator(
        sample_ds_1d,
        input_dims={'x': bsize},
        input_overlap={'x': olap},
    )
    # Consecutive batches share ``olap`` elements, so each start advances
    # by the remainder of the batch size.
    stride = bsize - olap
    for index, batch in enumerate(generator):
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x'] == bsize
        start = stride * index
        window = slice(start, start + bsize)
        reference = sample_ds_1d.isel(x=window)
        assert batch.equals(reference)
def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
    """Check batch dimensions for 2D inputs with ``concat_input_dims=True``.

    Every batch must expose ``x_input`` (length 20) and ``y_input`` (length
    ``bsize``), and a ``sample`` dimension whose length is the number of
    whole x-windows times the number of whole y-windows times ``time``.
    """
    # now iterate over both
    xbsize = 20
    generator = BatchGenerator(
        sample_ds_3d,
        input_dims={'y': bsize, 'x': xbsize},
        concat_input_dims=True,
    )
    for batch in generator:
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x_input'] == xbsize
        assert batch.dims['y_input'] == bsize
        x_windows = sample_ds_3d.dims['x'] // xbsize
        y_windows = sample_ds_3d.dims['y'] // bsize
        n_samples = x_windows * y_windows * sample_ds_3d.dims['time']
        assert batch.dims['sample'] == n_samples
def test_iterable_dataset(ds_xy):
    """Check that ``IterableDataset`` works with ``torch`` ``DataLoader``.

    Two generators over the ``x`` and ``y`` variables are wrapped in an
    ``IterableDataset``; every loaded pair must have matching lengths and
    tensor type, and the last loaded ``x`` batch must match the last batch
    of the ``x`` generator.
    """
    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    dataset = IterableDataset(x_gen, y_gen)

    # test integration with torch DataLoader
    loader = torch.utils.data.DataLoader(dataset)

    # Sentinel: the loop variable stays unbound if the loader yields
    # nothing, which would otherwise surface as a NameError below.
    x_batch = None
    for x_batch, y_batch in loader:
        assert len(x_batch) == len(y_batch)
        assert isinstance(x_batch, torch.Tensor)
    assert x_batch is not None, 'DataLoader yielded no batches'

    last_x = x_gen[-1]['x']
    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
    assert last_x.shape == x_batch.shape[1:]
    # TODO: also need to revisit the variable extraction bits here
    assert np.array_equal(last_x, x_batch[0, :, :])
def test_map_dataset(ds_xy):
    """Exercise ``MapDataset`` indexing, length and ``DataLoader`` use.

    Covers ``__getitem__`` with an int and with a one-element tensor, the
    ``NotImplementedError`` raised for a multi-element tensor index,
    ``__len__``, and iteration through a ``torch`` ``DataLoader``.
    """
    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    dataset = MapDataset(x_gen, y_gen)

    # test __getitem__
    x_batch, y_batch = dataset[0]
    assert len(x_batch) == len(y_batch)
    assert isinstance(x_batch, torch.Tensor)

    idx = torch.tensor([0])
    x_batch, y_batch = dataset[idx]
    assert len(x_batch) == len(y_batch)
    assert isinstance(x_batch, torch.Tensor)

    # Build the index outside the context manager so only the indexing
    # itself can satisfy the expected exception.
    multi_idx = torch.tensor([0, 1])
    with pytest.raises(NotImplementedError):
        dataset[multi_idx]

    # test __len__
    assert len(dataset) == len(x_gen)

    # test integration with torch DataLoader
    loader = torch.utils.data.DataLoader(dataset)

    # Use names distinct from the __getitem__ checks above: otherwise an
    # empty loader would leave a stale batch bound and the final assertions
    # would compare against it.
    loaded_x = None
    for loaded_x, loaded_y in loader:
        assert len(loaded_x) == len(loaded_y)
        assert isinstance(loaded_x, torch.Tensor)
    assert loaded_x is not None, 'DataLoader yielded no batches'

    last_x = x_gen[-1]['x']
    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
    assert last_x.shape == loaded_x.shape[1:]
    # TODO: also need to revisit the variable extraction bits here
    assert np.array_equal(last_x, loaded_x[0, :, :])
def test_batcher_getitem(sample_ds_1d):
    """Check integer indexing of a ``BatchGenerator`` and its error cases."""
    generator = BatchGenerator(sample_ds_1d, input_dims={'x': 10})

    # first batch (0), then last batch (-1)
    for position in (0, -1):
        assert generator[position].dims['x'] == 10

    # raises IndexError for out of range index
    with pytest.raises(IndexError, match=r'list index out of range'):
        generator[9999999]

    # raises NotImplementedError for iterable index
    with pytest.raises(NotImplementedError):
        generator[[1, 2, 3]]
def test_map_dataset_with_transform(ds_xy):
    """Check that ``MapDataset`` applies both transforms to its batches.

    The input transform maps every element to ``1`` and the target
    transform maps every element to ``-1``; the first item of the dataset
    must reflect both.
    """
    def x_transform(batch):
        """Turn every element of the input batch into 1."""
        return batch * 0 + 1

    def y_transform(batch):
        """Turn every element of the target batch into -1."""
        return batch * 0 - 1

    x_gen = BatchGenerator(ds_xy['x'], {'sample': 10})
    y_gen = BatchGenerator(ds_xy['y'], {'sample': 10})

    dataset = MapDataset(
        x_gen,
        y_gen,
        transform=x_transform,
        target_transform=y_transform,
    )

    inputs, targets = dataset[0]
    assert len(inputs) == len(targets)
    assert isinstance(inputs, torch.Tensor)
    assert (inputs == 1).all()
    assert (targets == -1).all()
def test_batch_3d_1d_input(sample_ds_3d, bsize):
    """Check batches of 3D data when only ``x`` is an input dimension.

    Each batch must have ``x`` of length ``bsize`` and a ``sample``
    dimension of length ``len(y) * len(time)``, and must equal the matching
    ``x`` slice of the fixture with ``time`` and ``y`` stacked into
    ``sample`` and transposed to ``('sample', 'x')``.
    """
    # first do the iteration over just one dimension
    bg = BatchGenerator(sample_ds_3d, input_dims={'x': bsize})
    for n, ds_batch in enumerate(bg):
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.dims['x'] == bsize
        # time and y should be collapsed into batch dimension
        assert (
            ds_batch.dims['sample']
            == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time']
        )
        expected_slice = slice(bsize * n, bsize * (n + 1))
        ds_batch_expected = (
            sample_ds_3d.isel(x=expected_slice)
            .stack(sample=['time', 'y'])
            .transpose('sample', 'x')
        )
        assert ds_batch.equals(ds_batch_expected)
def test_constructor_coerces_to_dataset():
    """A ``DataArray`` input must be stored on the generator as a Dataset."""
    values = np.random.rand(10)
    data_array = xr.DataArray(values, dims='x', name='foo')
    generator = BatchGenerator(data_array, input_dims={'x': 2})
    stored = generator.ds
    assert isinstance(stored, xr.Dataset)
    # The stored object must match the array's own Dataset conversion.
    assert stored.equals(data_array.to_dataset())
def test_batcher_lenth(sample_ds_1d, bsize):
    """``len()`` of the generator must equal ``len(x) // bsize``."""
    generator = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
    n_batches = len(generator)
    assert n_batches == sample_ds_1d.dims['x'] // bsize