def test_dask_array_creates_cache():
    """Test that dask arrays create a cache but turn off fusion."""
    resize_dask_cache(1)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 1
    # by default we have no dask_cache and task fusion is active
    original = dask.config.get("optimization.fuse.active", None)

    def mock_set_view_slice():
        assert dask.config.get("optimization.fuse.active") is False

    layer = layers.Image(da.ones((100, 100)))
    layer._set_view_slice = mock_set_view_slice
    layer.set_view_slice()

    # adding a dask array will create a cache and turn off task fusion,
    # *but only* during slicing (see "mock_set_view_slice" above)
    assert _dask_utils._DASK_CACHE.cache.available_bytes > 100
    assert not _dask_utils._DASK_CACHE.active
    assert dask.config.get("optimization.fuse.active", None) == original

    # make sure we can resize the cache
    resize_dask_cache(10000)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 10000

    # this should only affect dask arrays, and not numpy data
    def mock_set_view_slice2():
        assert dask.config.get("optimization.fuse.active", None) == original

    layer2 = layers.Image(np.ones((100, 100)))
    layer2._set_view_slice = mock_set_view_slice2
    layer2.set_view_slice()

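# The fusion toggling asserted above can be reproduced in isolation:
# dask.config.set works as a context manager and restores the previous value
# on exit. `fusion_off` below is a hypothetical helper sketching the behavior
# that napari's configure_dask provides during slicing, not its actual
# implementation.
from contextlib import contextmanager


@contextmanager
def fusion_off():
    with dask.config.set({"optimization.fuse.active": False}):
        yield


def test_fusion_off_sketch():
    original = dask.config.get("optimization.fuse.active", None)
    with fusion_off():
        assert dask.config.get("optimization.fuse.active") is False
    # outside the context, the global setting is untouched
    assert dask.config.get("optimization.fuse.active", None) == original
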
def test_dask_optimized_slicing(delayed_dask_stack, monkeypatch):
    """Test that dask_configure reduces compute with dask stacks."""
    # make sure we have a cache big enough for 10+ (10, 10, 10) "timepoints"
    utils.resize_dask_cache(100000)

    # add dask stack to the viewer, making sure to pass multiscale and clims
    v = viewer.ViewerModel()
    dask_stack = delayed_dask_stack['stack']
    v.add_image(dask_stack, multiscale=False, contrast_limits=(0, 1))
    assert delayed_dask_stack['calls'] == 1  # the first stack will be loaded

    # changing the Z plane should never incur calls
    # since the stack has already been loaded (& it is chunked as a 3D array)
    for i in range(3):
        v.dims.set_point(1, i)
        assert delayed_dask_stack['calls'] == 1  # still just the first call

    # changing the timepoint will, of course, incur some compute calls
    v.dims.set_point(0, 1)
    assert delayed_dask_stack['calls'] == 2
    v.dims.set_point(0, 2)
    assert delayed_dask_stack['calls'] == 3

    # but going back to previous timepoints should not, since they are cached
    v.dims.set_point(0, 1)
    v.dims.set_point(0, 0)
    assert delayed_dask_stack['calls'] == 3
    v.dims.set_point(0, 3)
    assert delayed_dask_stack['calls'] == 4

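# The tests in this file assume a `delayed_dask_stack` fixture shaped roughly
# like the sketch below: a 4D dask array whose "timepoints" are produced by
# dask.delayed calls, plus a counter recording how many times the underlying
# reader actually runs. Shapes and names here are assumptions for
# illustration, not necessarily the real fixture verbatim.
@pytest.fixture
def delayed_dask_stack_sketch():
    output = {'calls': 0}

    def get_array():
        # each execution simulates reading one full (10, 10, 10) timepoint
        output['calls'] += 1
        return np.random.rand(10, 10, 10)

    delayed_arrays = [
        da.from_delayed(
            dask.delayed(get_array)(), shape=(10, 10, 10), dtype=float
        )
        for _ in range(4)
    ]
    output['stack'] = da.stack(delayed_arrays)
    return output
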
def test_list_of_dask_arrays_creates_cache():
    """Test that adding a list of dask arrays also creates a dask cache."""
    resize_dask_cache(1)  # in case other tests created it
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 1
    original = dask.config.get("optimization.fuse.active", None)
    _ = layers.Image([da.ones((100, 100)), da.ones((20, 20))])
    assert _dask_utils._DASK_CACHE.cache.available_bytes > 100
    assert not _dask_utils._DASK_CACHE.active
    assert dask.config.get("optimization.fuse.active", None) == original

def test_prevent_dask_cache(delayed_dask_stack):
    """Test that pre-emptively setting cache to zero keeps it off."""
    resize_dask_cache(0)

    v = ViewerModel()
    dask_stack = delayed_dask_stack['stack']

    # adding a new stack will not increase the cache size
    v.add_image(dask_stack)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0
    # and the cache will not be populated
    for i in range(3):
        v.dims.set_point(0, i)
    assert len(_dask_utils._DASK_CACHE.cache.heap.heap) == 0

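# Mechanically, the caching exercised above rests on dask.cache.Cache, a
# wrapper around a cachey.Cache: registering it makes dask opportunistically
# store task results, and unregistering stops that. A minimal standalone
# sketch of those primitives, independent of napari's internals:
def test_dask_cache_mechanics_sketch():
    import dask.cache  # no-op if already imported at module level

    cache = dask.cache.Cache(1e6)  # opportunistic cache, ~1 MB
    cache.register()  # dask now stores (some) task results in the cache
    assert cache.cache.available_bytes == 1e6
    cache.unregister()  # stop caching; existing entries are untouched
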
def test_prevent_dask_cache(delayed_dask_stack):
    """Test that pre-emptively setting cache to zero keeps it off."""
    # the del is not required, it just shows that prior state of the cache
    # does not matter... calling resize_dask_cache(0) will permanently
    # disable the cache
    del utils.dask_cache
    utils.resize_dask_cache(0)

    v = viewer.ViewerModel()
    dask_stack = delayed_dask_stack['stack']

    # adding a new stack will not increase the cache size
    v.add_image(dask_stack, multiscale=False, contrast_limits=(0, 1))
    assert utils.dask_cache.cache.available_bytes == 0
    # and the cache will not be populated
    for i in range(3):
        v.dims.set_point(0, i)
    assert len(utils.dask_cache.cache.heap.heap) == 0

def test_dask_cache_resizing(delayed_dask_stack):
    """Test that we can spin up, resize, and spin down the cache."""
    utils.dask_cache = None
    v = viewer.ViewerModel()
    dask_stack = delayed_dask_stack['stack']

    # adding a new stack should spin up a cache
    # (making sure to pass multiscale and clims)
    v.add_image(dask_stack, multiscale=False, contrast_limits=(0, 1))
    assert utils.dask_cache.cache.available_bytes > 0
    # make sure the cache actually has been populated
    assert len(utils.dask_cache.cache.heap.heap) > 0

    # we can resize that cache back to 0 bytes
    utils.resize_dask_cache(0)
    assert utils.dask_cache.cache.available_bytes == 0

    # adding a 2nd stack should not adjust the cache size once created
    v.add_image(dask_stack, multiscale=False, contrast_limits=(0, 1))
    assert utils.dask_cache.cache.available_bytes == 0

    # and the cache will remain empty regardless of what we do
    for i in range(3):
        v.dims.set_point(1, i)
    assert len(utils.dask_cache.cache.heap.heap) == 0

    # but we can always spin it up again
    utils.resize_dask_cache(1e4)
    assert utils.dask_cache.cache.available_bytes == 1e4

    # and adding a new image doesn't change the size
    v.add_image(dask_stack, multiscale=False, contrast_limits=(0, 1))
    assert utils.dask_cache.cache.available_bytes == 1e4

    # but the cache heap is getting populated again
    for i in range(3):
        v.dims.set_point(0, i)
    assert len(utils.dask_cache.cache.heap.heap) > 0

    # however, if the dask_cache attribute is deleted entirely (or set to
    # None), we will have no memory of it ever having been created,
    # and adding a new stack will spin up a cache
    del utils.dask_cache
    v.add_image(dask_stack, multiscale=False, contrast_limits=(0, 1))
    assert utils.dask_cache.cache.available_bytes > 0

def test_dask_unoptimized_slicing(delayed_dask_stack, monkeypatch):
    """Prove that the dask_configure function works with a counterexample."""
    # make sure we are not caching for this test, which also tests that we
    # can turn off caching
    utils.resize_dask_cache(0)
    assert utils.dask_cache.cache.available_bytes == 0

    # mock the dask_configure function to return a no-op
    def mock_dask_config(data):
        @contextmanager
        def dask_optimized_slicing(*args, **kwds):
            yield {}

        return dask_optimized_slicing

    monkeypatch.setattr(layers.base.base, 'configure_dask', mock_dask_config)

    # add dask stack to viewer
    v = viewer.ViewerModel()
    dask_stack = delayed_dask_stack['stack']
    v.add_image(dask_stack, multiscale=False, contrast_limits=(0, 1))
    assert delayed_dask_stack['calls'] == 1

    # without optimized dask slicing, we get a new call to the get_array
    # func (which "re-reads" the full z stack) EVERY time we change the Z
    # plane, even though we've already read this full timepoint
    for i in range(3):
        v.dims.set_point(1, i)
        assert delayed_dask_stack['calls'] == 1 + i  # 😞

    # of course we still incur calls when moving to a new timepoint...
    v.dims.set_point(0, 1)
    v.dims.set_point(0, 2)
    assert delayed_dask_stack['calls'] == 5

    # without the cache we ALSO incur calls when returning to previously
    # loaded timepoints 😠
    v.dims.set_point(0, 1)
    v.dims.set_point(0, 0)
    v.dims.set_point(0, 3)
    # all told, we have ~2x as many calls as the optimized version above
    # (should be exactly 8 calls, but for some reason, sometimes less on CI)
    assert delayed_dask_stack['calls'] >= 7

def test_dask_array_creates_cache():
    """Test that adding a dask array creates a dask cache and turns off
    fusion.
    """
    # by default we have no dask_cache and task fusion is active
    original = dask.config.get("optimization.fuse.active", None)

    def mock_set_view_slice():
        assert dask.config.get("optimization.fuse.active") is False

    layer = layers.Image(da.ones((100, 100)))
    layer._set_view_slice = mock_set_view_slice
    layer.set_view_slice()

    # adding a dask array will turn on the cache, and turn off task fusion
    assert isinstance(utils.dask_cache, dask.cache.Cache)
    assert dask.config.get("optimization.fuse.active", None) == original

    # if the dask version is too low to remove task fusion, emit a warning
    _dask_ver = dask.__version__
    dask.__version__ = '2.14.0'
    with pytest.warns(UserWarning) as record:
        _ = layers.Image(da.ones((100, 100)))
    assert 'upgrade Dask to v2.15.0 or later' in record[0].message.args[0]
    dask.__version__ = _dask_ver

    # make sure we can resize the cache
    assert utils.dask_cache.cache.total_bytes > 1000
    utils.resize_dask_cache(1000)
    assert utils.dask_cache.cache.total_bytes <= 1000

    # this should only affect dask arrays, and not numpy data
    def mock_set_view_slice2():
        assert dask.config.get("optimization.fuse.active", None) == original

    layer2 = layers.Image(np.ones((100, 100)))
    layer2._set_view_slice = mock_set_view_slice2
    layer2.set_view_slice()

    # clean up cache
    utils.dask_cache = None

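# The warning path exercised above gates on the dask version. A hedged
# sketch of such a check follows; the helper name and exact wording are
# illustrative (only the 'upgrade Dask to v2.15.0 or later' substring is
# pinned down by the assertion in the test):
def _warn_on_old_dask_sketch():
    import warnings

    from packaging.version import parse

    if parse(dask.__version__) < parse('2.15.0'):
        warnings.warn(
            'Dask versions before 2.15.0 cannot disable task fusion; '
            'please upgrade Dask to v2.15.0 or later for best performance.'
        )
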
def test_dask_array_doesnt_create_cache():
    """Test that dask arrays don't create a cache, but do turn off fusion."""
    # by default we have no dask_cache and task fusion is active
    original = dask.config.get("optimization.fuse.active", None)

    def mock_set_view_slice():
        assert dask.config.get("optimization.fuse.active") is False

    layer = layers.Image(da.ones((100, 100)))
    layer._set_view_slice = mock_set_view_slice
    layer.set_view_slice()

    # adding a dask array won't create a cache, but will turn off task
    # fusion, *but only* during slicing (see "mock_set_view_slice" above)
    assert utils.dask_cache is None
    assert dask.config.get("optimization.fuse.active", None) == original

    # if the dask version is too low to remove task fusion, emit a warning
    _dask_ver = dask.__version__
    dask.__version__ = '2.14.0'
    with pytest.warns(UserWarning) as record:
        _ = layers.Image(da.ones((100, 100)))
    assert 'upgrade Dask to v2.15.0 or later' in record[0].message.args[0]
    dask.__version__ = _dask_ver

    # make sure we can resize the cache
    utils.resize_dask_cache(10000)
    assert utils.dask_cache.cache.available_bytes == 10000

    # this should only affect dask arrays, and not numpy data
    def mock_set_view_slice2():
        assert dask.config.get("optimization.fuse.active", None) == original

    layer2 = layers.Image(np.ones((100, 100)))
    layer2._set_view_slice = mock_set_view_slice2
    layer2.set_view_slice()

    # clean up cache
    utils.dask_cache = None

def test_dask_unoptimized_slicing(delayed_dask_stack, monkeypatch):
    """Prove that the dask_configure function works with a counterexample."""
    # we start with a cache... but then intentionally turn it off per-layer
    resize_dask_cache(10000)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 10000

    # add dask stack to viewer
    v = ViewerModel()
    dask_stack = delayed_dask_stack['stack']
    layer = v.add_image(dask_stack, cache=False)
    # the first and the middle stack will be loaded
    assert delayed_dask_stack['calls'] == 2
    with layer.dask_optimized_slicing() as (_, cache):
        assert cache is None

    # without optimized dask slicing, we get a new call to the get_array
    # func (which "re-reads" the full z stack) EVERY time we change the Z
    # plane, even though we've already read this full timepoint
    current_z = v.dims.point[1]
    for i in range(3):
        v.dims.set_point(1, current_z + i)
        assert delayed_dask_stack['calls'] == 2 + i  # 😞

    # of course we still incur calls when moving to a new timepoint...
    initial_t = v.dims.point[0]
    v.dims.set_point(0, initial_t + 1)
    v.dims.set_point(0, initial_t + 2)
    assert delayed_dask_stack['calls'] == 6

    # without the cache we ALSO incur calls when returning to previously
    # loaded timepoints 😠
    v.dims.set_point(0, initial_t + 1)
    v.dims.set_point(0, initial_t + 0)
    v.dims.set_point(0, initial_t + 3)
    # all told, we have ~2x as many calls as the optimized version above
    # (should be exactly 9 calls, but for some reason, sometimes more on CI)
    assert delayed_dask_stack['calls'] >= 9

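# The `with layer.dask_optimized_slicing() as (_, cache)` line above shows
# the shape of the layer's slicing context: it yields a (config, cache)
# pair, with cache None when caching is disabled for that layer. A
# standalone sketch of that contract (hypothetical, not napari's source):
from contextlib import contextmanager  # already used elsewhere in this file


@contextmanager
def dask_slicing_context_sketch(use_cache):
    cache = _dask_utils._DASK_CACHE if use_cache else None
    with dask.config.set({"optimization.fuse.active": False}):
        yield dask.config.config, cache
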
def test_dask_local_unoptimized_slicing(delayed_dask_stack, monkeypatch):
    """Prove that the dask_configure function works with a counterexample."""
    # make sure we are not caching for this test, which also tests that we
    # can turn off caching
    resize_dask_cache(0)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0
    monkeypatch.setattr(
        layers.base.base, 'configure_dask', lambda *_: nullcontext
    )

    # add dask stack to viewer
    v = ViewerModel()
    dask_stack = delayed_dask_stack['stack']
    v.add_image(dask_stack, cache=False)
    # the first and the middle stack will be loaded
    assert delayed_dask_stack['calls'] == 2

    # without optimized dask slicing, we get a new call to the get_array
    # func (which "re-reads" the full z stack) EVERY time we change the Z
    # plane, even though we've already read this full timepoint
    for i in range(3):
        v.dims.set_point(1, i)
        assert delayed_dask_stack['calls'] == 2 + 1 + i  # 😞

    # of course we still incur calls when moving to a new timepoint...
    v.dims.set_point(0, 1)
    v.dims.set_point(0, 2)
    assert delayed_dask_stack['calls'] == 7

    # without the cache we ALSO incur calls when returning to previously
    # loaded timepoints 😠
    v.dims.set_point(0, 1)
    v.dims.set_point(0, 0)
    v.dims.set_point(0, 3)
    # all told, we have ~2x as many calls as the optimized version above
    # (should be exactly 10 calls, but for some reason, sometimes less on CI)
    assert delayed_dask_stack['calls'] >= 10

def test_dask_cache_resizing(delayed_dask_stack):
    """Test that we can spin up, resize, and spin down the cache."""
    # make sure we have a cache big enough for 10+ (10, 10, 10) "timepoints"
    resize_dask_cache(100000)

    # add dask stack to the viewer
    v = ViewerModel()
    dask_stack = delayed_dask_stack['stack']
    v.add_image(dask_stack)
    assert _dask_utils._DASK_CACHE.cache.available_bytes > 0
    # make sure the cache actually has been populated
    assert len(_dask_utils._DASK_CACHE.cache.heap.heap) > 0

    # we can resize that cache back to 0 bytes
    resize_dask_cache(0)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0

    # adding a 2nd stack should not adjust the cache size once created
    v.add_image(dask_stack)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0

    # and the cache will remain empty regardless of what we do
    for i in range(3):
        v.dims.set_point(1, i)
    assert len(_dask_utils._DASK_CACHE.cache.heap.heap) == 0

    # but we can always spin it up again
    resize_dask_cache(1e4)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 1e4

    # and adding a new image doesn't change the size
    v.add_image(dask_stack)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 1e4

    # but the cache heap is getting populated again
    for i in range(3):
        v.dims.set_point(0, i)
    assert len(_dask_utils._DASK_CACHE.cache.heap.heap) > 0

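# The resizing helper is also importable from napari.utils (the MNIST script
# below imports it from there). A quick sketch of the public API, assuming it
# resizes the same module-level cache exercised above:
def test_resize_dask_cache_public_api_sketch():
    from napari.utils import resize_dask_cache as public_resize

    public_resize(10000)  # cap the opportunistic cache at 10 kB
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 10000
    public_resize(0)  # effectively disable caching
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0
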
#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST

import napari
from napari.utils import resize_dask_cache

resize_dask_cache(0)

mnist_train = MNIST(
    'data/MNIST',
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    train=True,
)
mnist_test = MNIST(
    'data/MNIST',
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    train=False,
)

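# A script like this presumably goes on to hand the MNIST images to napari.
# One hedged sketch of that next step (the chunking and names are arbitrary
# illustration, not part of the original script): wrap the in-memory tensor
# as a dask array and view it. With resize_dask_cache(0) above, napari will
# not spend memory caching slices of data that is already in RAM.
import dask.array as da  # would normally sit with the imports above

train_images = da.from_array(mnist_train.data.numpy(), chunks=(1000, 28, 28))
viewer = napari.view_image(train_images, name='mnist')
napari.run()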