Ejemplo n.º 1
0
def test_batch_accessor_ds(sample_ds_3d):
    """The Dataset ``.batch`` accessor must match a directly built BatchGenerator.

    Both generators use identical ``input_dims``. The accessor must return a
    ``BatchGenerator``, and every batch it yields must be an ``xr.Dataset``
    equal to the corresponding batch from the class-based generator.
    """
    bg_class = BatchGenerator(sample_ds_3d, input_dims={'x': 5})
    bg_acc = sample_ds_3d.batch.generator(input_dims={'x': 5})
    assert isinstance(bg_acc, BatchGenerator)
    # zip() stops at the shorter iterable, so compare lengths first; otherwise
    # a missing or extra batch on either side would go unnoticed.
    assert len(bg_class) == len(bg_acc)
    for batch_class, batch_acc in zip(bg_class, bg_acc):
        assert isinstance(batch_acc, xr.Dataset)
        assert batch_class.equals(batch_acc)
Ejemplo n.º 2
0
def test_batch_accessor_da(sample_ds_3d):
    """The DataArray ``.batch`` accessor must match a directly built BatchGenerator.

    The ``foo`` variable is extracted as a DataArray. Batches from the accessor
    must be equal, one for one, to batches from ``BatchGenerator`` with the
    same ``input_dims``.
    """
    sample_da = sample_ds_3d['foo']
    bg_class = BatchGenerator(sample_da, input_dims={'x': 5})
    bg_acc = sample_da.batch.generator(input_dims={'x': 5})
    assert isinstance(bg_acc, BatchGenerator)
    # zip() stops at the shorter iterable, so compare lengths first; otherwise
    # a missing or extra batch on either side would go unnoticed.
    assert len(bg_class) == len(bg_acc)
    for batch_class, batch_acc in zip(bg_class, bg_acc):
        assert batch_class.equals(batch_acc)
Ejemplo n.º 3
0
def test_batch_1d(sample_ds_1d, bsize):
    """Batches along ``x`` are consecutive, non-overlapping ``bsize`` slices."""
    generator = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
    for index, batch in enumerate(generator):
        assert isinstance(batch, xr.Dataset)
        # TODO: maybe relax this? see comment above
        assert batch.dims['x'] == bsize
        start = bsize * index
        expected = sample_ds_1d.isel(x=slice(start, start + bsize))
        assert batch.equals(expected)
Ejemplo n.º 4
0
def test_preload_batch_true(sample_ds_1d):
    """With ``preload_batch=True`` the yielded batches are no longer chunked.

    The input is chunked along ``x`` first. A batch that was not loaded
    would still report non-empty ``chunks``.
    """
    chunked = sample_ds_1d.chunk({'x': 2})
    generator = BatchGenerator(chunked, input_dims={'x': 2}, preload_batch=True)
    assert generator.preload_batch is True
    for batch in generator:
        assert isinstance(batch, xr.Dataset)
        assert not batch.chunks
Ejemplo n.º 5
0
def test_batch_1d_concat(sample_ds_1d, bsize):
    """``concat_input_dims=True`` concatenates the ``x`` windows in each batch.

    Each batch must have an ``x_input`` dimension of length ``bsize``. It must
    also have an ``input_batch`` dimension whose length is the number of
    complete ``bsize`` windows along ``x``. The ``x`` coordinate is retained.
    """
    bg = BatchGenerator(
        sample_ds_1d, input_dims={'x': bsize}, concat_input_dims=True
    )
    n_windows = sample_ds_1d.dims['x'] // bsize
    # The batch index is not needed, so iterate the generator directly.
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.dims['x_input'] == bsize
        assert ds_batch.dims['input_batch'] == n_windows
        assert 'x' in ds_batch.coords
Ejemplo n.º 6
0
def test_batch_1d_no_coordinate(sample_ds_1d, bsize):
    """Batching along ``x`` still works when ``x`` has no coordinate variable."""
    # fix for #3
    without_coord = sample_ds_1d.drop_vars('x')
    generator = BatchGenerator(without_coord, input_dims={'x': bsize})
    for index, batch in enumerate(generator):
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x'] == bsize
        lo = index * bsize
        hi = lo + bsize
        assert batch.equals(without_coord.isel(x=slice(lo, hi)))
Ejemplo n.º 7
0
def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize):
    """Concatenated batching works when ``x`` has no coordinate variable.

    The shape expectations match the coordinate case (``x_input`` of length
    ``bsize``, ``input_batch`` equal to the number of complete windows).
    However, no ``x`` coordinate may appear in the output.
    """
    # test for #3
    ds_dropped = sample_ds_1d.drop_vars('x')
    bg = BatchGenerator(
        ds_dropped, input_dims={'x': bsize}, concat_input_dims=True
    )
    n_windows = sample_ds_1d.dims['x'] // bsize
    # The batch index is not needed, so iterate the generator directly.
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.dims['x_input'] == bsize
        assert ds_batch.dims['input_batch'] == n_windows
        assert 'x' not in ds_batch.coords
Ejemplo n.º 8
0
def test_batch_3d_2d_input(sample_ds_3d, bsize):
    """Batching over both ``y`` and ``x`` yields one batch per complete 2-D tile.

    Every batch has the requested ``x``/``y`` sizes. The number of batches
    equals the number of complete ``(y, x)`` tiles that fit in the input.
    """
    # now iterate over both
    xbsize = 20
    bg = BatchGenerator(sample_ds_3d, input_dims={'y': bsize, 'x': xbsize})
    # Count explicitly: reading the loop variable after the loop raises a
    # NameError (rather than a clear assertion failure) if nothing is yielded.
    n_batches = 0
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.dims['x'] == xbsize
        assert ds_batch.dims['y'] == bsize
        # TODO? Is it worth it to try to reproduce the internal logic of the
        # generator and verify that the slices are correct?
        n_batches += 1
    expected_n_batches = (
        (sample_ds_3d.dims['x'] // xbsize) * (sample_ds_3d.dims['y'] // bsize)
    )
    assert n_batches == expected_n_batches
Ejemplo n.º 9
0
def test_batch_1d_overlap(sample_ds_1d, olap):
    """Overlapping windows of length 10 advance by ``10 - olap`` along ``x``."""
    bsize = 10
    generator = BatchGenerator(
        sample_ds_1d, input_dims={'x': bsize}, input_overlap={'x': olap}
    )
    step = bsize - olap
    for index, batch in enumerate(generator):
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x'] == bsize
        start = index * step
        expected = sample_ds_1d.isel(x=slice(start, start + bsize))
        assert batch.equals(expected)
Ejemplo n.º 10
0
def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
    """Concatenated 2-D batching flattens all tiles and times into ``sample``.

    Each batch has ``x_input``/``y_input`` dimensions of the requested sizes.
    Its ``sample`` dimension has one entry per complete ``(y, x)`` tile per
    time step.
    """
    # now iterate over both
    xbsize = 20
    bg = BatchGenerator(sample_ds_3d, input_dims={'y': bsize, 'x': xbsize},
                        concat_input_dims=True)
    expected_n_samples = ((sample_ds_3d.dims['x'] // xbsize) *
                          (sample_ds_3d.dims['y'] // bsize) *
                          sample_ds_3d.dims['time'])
    # The batch index is not needed, so iterate the generator directly.
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.dims['x_input'] == xbsize
        assert ds_batch.dims['y_input'] == bsize
        assert ds_batch.dims['sample'] == expected_n_samples
Ejemplo n.º 11
0
def test_iterable_dataset(ds_xy):
    """IterableDataset pairs x/y batches and works with a torch DataLoader."""

    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    dataset = IterableDataset(x_gen, y_gen)

    # test integration with torch DataLoader
    loader = torch.utils.data.DataLoader(dataset)

    n_batches = 0
    for x_batch, y_batch in loader:
        assert len(x_batch) == len(y_batch)
        assert isinstance(x_batch, torch.Tensor)
        n_batches += 1
    # The checks below inspect the last batch from the loop; fail with a clear
    # assertion (not a NameError) if the loader yielded nothing.
    assert n_batches > 0

    # The DataLoader uses its default batch_size of 1, so default collation
    # wraps each item in a batch of one, adding a leading length-1 dimension.
    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
    # TODO: also need to revisit the variable extraction bits here
    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
Ejemplo n.º 12
0
def test_map_dataset(ds_xy):
    """MapDataset supports indexing, ``len()`` and use with a torch DataLoader."""

    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    dataset = MapDataset(x_gen, y_gen)

    # test __getitem__ with an integer index
    x_batch, y_batch = dataset[0]
    assert len(x_batch) == len(y_batch)
    assert isinstance(x_batch, torch.Tensor)

    # test __getitem__ with a single-element tensor index
    idx = torch.tensor([0])
    x_batch, y_batch = dataset[idx]
    assert len(x_batch) == len(y_batch)
    assert isinstance(x_batch, torch.Tensor)

    # Multi-element tensor indices are not supported. Build the index outside
    # the raises-block so only the indexing call itself can satisfy the check.
    idx = torch.tensor([0, 1])
    with pytest.raises(NotImplementedError):
        dataset[idx]

    # test __len__
    assert len(dataset) == len(x_gen)

    # test integration with torch DataLoader
    loader = torch.utils.data.DataLoader(dataset)

    n_batches = 0
    for x_batch, y_batch in loader:
        assert len(x_batch) == len(y_batch)
        assert isinstance(x_batch, torch.Tensor)
        n_batches += 1
    # With batch_size=1 every item becomes its own batch. Checking the count
    # also guarantees x_batch below comes from the loader, not from the stale
    # dataset[idx] result above.
    assert n_batches == len(dataset)

    # The DataLoader uses its default batch_size of 1, so default collation
    # wraps each item in a batch of one, adding a leading length-1 dimension.
    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
    # TODO: also need to revisit the variable extraction bits here
    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
Ejemplo n.º 13
0
def test_batcher_getitem(sample_ds_1d):
    """Integer indexing returns full batches; invalid indices raise errors."""
    generator = BatchGenerator(sample_ds_1d, input_dims={'x': 10})

    # both the first and the last batch have the full window length
    for position in (0, -1):
        assert generator[position].dims['x'] == 10

    # an index far past the end raises IndexError
    with pytest.raises(IndexError, match=r'list index out of range'):
        generator[9999999]

    # a list of indices is not supported
    with pytest.raises(NotImplementedError):
        generator[[1, 2, 3]]
Ejemplo n.º 14
0
def test_map_dataset_with_transform(ds_xy):
    """MapDataset applies ``transform`` to x batches and ``target_transform`` to y."""
    x_gen = BatchGenerator(ds_xy['x'], {'sample': 10})
    y_gen = BatchGenerator(ds_xy['y'], {'sample': 10})

    def x_transform(batch):
        # map every element to +1
        return batch * 0 + 1

    def y_transform(batch):
        # map every element to -1
        return batch * 0 - 1

    dataset = MapDataset(
        x_gen, y_gen, transform=x_transform, target_transform=y_transform
    )
    features, target = dataset[0]
    assert len(features) == len(target)
    assert isinstance(features, torch.Tensor)
    assert (features == 1).all()
    assert (target == -1).all()
Ejemplo n.º 15
0
def test_batch_3d_1d_input(sample_ds_3d, bsize):
    """Batching a 3-D dataset along ``x`` stacks ``time`` and ``y`` into ``sample``.

    Each batch must equal the matching ``x`` slice of the input after stacking
    ``('time', 'y')`` into ``sample`` and transposing to ``('sample', 'x')``.
    """

    # first do the iteration over just one dimension
    bg = BatchGenerator(sample_ds_3d, input_dims={'x': bsize})
    expected_n_samples = sample_ds_3d.dims['y'] * sample_ds_3d.dims['time']
    for n, ds_batch in enumerate(bg):
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.dims['x'] == bsize
        # time and y should be collapsed into batch dimension
        assert ds_batch.dims['sample'] == expected_n_samples
        expected_slice = slice(bsize * n, bsize * (n + 1))
        ds_batch_expected = (
            sample_ds_3d.isel(x=expected_slice)
            .stack(sample=['time', 'y'])
            .transpose('sample', 'x')
        )
        assert ds_batch.equals(ds_batch_expected)
Ejemplo n.º 16
0
def test_constructor_coerces_to_dataset():
    """A DataArray passed to BatchGenerator is stored as a Dataset on ``.ds``."""
    values = np.random.rand(10)
    array = xr.DataArray(values, dims='x', name='foo')
    generator = BatchGenerator(array, input_dims={'x': 2})
    assert isinstance(generator.ds, xr.Dataset)
    assert generator.ds.equals(array.to_dataset())
Ejemplo n.º 17
0
def test_batcher_lenth(sample_ds_1d, bsize):
    """``len()`` equals the number of complete ``bsize`` windows along ``x``."""
    # NOTE: the "lenth" spelling is kept so the test id stays stable.
    expected = sample_ds_1d.dims['x'] // bsize
    generator = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
    assert len(generator) == expected