def test_atop_non_atop_output():
    """Check ``optimize_atop`` on graphs mixing atop and non-atop outputs.

    ``y`` is an elementwise chain on ``x``; ``w`` is the ``sum`` reduction of
    ``y`` and ``z`` is a further elementwise chain on ``y``.  The test checks
    that:

    * ``dask.optimize`` leaves the indices of the TOP layer stored in
      ``z``'s own graph unchanged;
    * optimizing ``z``'s graph alone yields a ``HighLevelGraph`` with exactly
      one TOP layer;
    * optimizing the merged ``w``/``z`` graph yields a ``HighLevelGraph`` that
      still contains at least one TOP layer.
    """
    x = da.ones(10, chunks=(5,))
    y = (((x + 1) + 2) + 3)
    w = y.sum()
    z = (((y * 2) * 3) * 4)

    # Optimizing must not mutate the TOP layer held by z's original graph.
    z_top_before = tuple(z.dask.dicts[z.name].indices)
    (zz,) = dask.optimize(z)
    z_top_after = tuple(z.dask.dicts[z.name].indices)
    assert z_top_before == z_top_after, "z_top mutated"

    dsk = optimize_atop(z.dask, keys=list(dask.core.flatten(z.__dask_keys__())))
    assert isinstance(dsk, HighLevelGraph)
    assert len([layer for layer in dsk.dicts.values() if isinstance(layer, TOP)]) == 1

    dsk = optimize_atop(
        HighLevelGraph.merge(w.dask, z.dask),
        keys=list(dask.core.flatten([w.__dask_keys__(), z.__dask_keys__()])),
    )
    assert isinstance(dsk, HighLevelGraph)
    # Inspect the optimized graph ``dsk``; counting layers in ``z.dask`` would
    # only re-check the unoptimized input and never exercise optimize_atop.
    assert len([layer for layer in dsk.dicts.values() if isinstance(layer, TOP)]) >= 1
def test_atop_non_atop_output():
    """Check ``optimize_atop`` on graphs mixing atop and non-atop outputs.

    ``y`` is an elementwise chain on ``x``; ``w`` is the ``sum`` reduction of
    ``y`` and ``z`` is a further elementwise chain on ``y``.  The test checks
    that:

    * ``dask.optimize`` leaves the indices of the TOP layer stored in
      ``z``'s own graph unchanged;
    * optimizing ``z``'s graph alone yields a ``ShareDict`` with exactly one
      TOP layer;
    * optimizing the merged ``w``/``z`` graph yields a ``ShareDict`` that
      still contains at least one TOP layer.
    """
    x = da.ones(10, chunks=(5,))
    y = (((x + 1) + 2) + 3)
    w = y.sum()
    z = (((y * 2) * 3) * 4)

    # Optimizing must not mutate the TOP layer held by z's original graph.
    z_top_before = tuple(z.dask.dicts[z.name].indices)
    (zz,) = dask.optimize(z)
    z_top_after = tuple(z.dask.dicts[z.name].indices)
    assert z_top_before == z_top_after, "z_top mutated"

    dsk = optimize_atop(z.dask, keys=list(dask.core.flatten(z.__dask_keys__())))
    assert isinstance(dsk, dask.sharedict.ShareDict)
    assert len([layer for layer in dsk.dicts.values() if isinstance(layer, TOP)]) == 1

    dsk = optimize_atop(
        dask.sharedict.merge(w.dask, z.dask),
        keys=list(dask.core.flatten([w.__dask_keys__(), z.__dask_keys__()])),
    )
    assert isinstance(dsk, dask.sharedict.ShareDict)
    # Inspect the optimized graph ``dsk``; counting layers in ``z.dask`` would
    # only re-check the unoptimized input and never exercise optimize_atop.
    assert len([layer for layer in dsk.dicts.values() if isinstance(layer, TOP)]) >= 1
def test_dont_merge_before_reductions():
    """Three chained atop layers must optimize to exactly two TOP layers.

    The chain is an elementwise ``inc`` over index ``'i'``, then a ``sum``
    that drops ``'i'`` (output index ``''``), then another ``sum`` on that
    scalar-indexed result.  ``optimize_atop`` on the final graph is expected
    to keep two TOP layers rather than fusing the whole chain into one.
    """
    ones = da.ones(10, chunks=(5,))
    bumped = da.atop(inc, 'i', ones, 'i', dtype=ones.dtype)
    reduced = da.atop(sum, '', bumped, 'i', dtype=bumped.dtype)
    rereduced = da.atop(sum, '', reduced, '', dtype=bumped.dtype)

    optimized = optimize_atop(rereduced.dask)
    top_layers = [
        layer for layer in optimized.dicts.values() if isinstance(layer, TOP)
    ]
    assert len(top_layers) == 2

    # The intermediate reduction must still compute on its own.
    reduced.compute()
def test_dont_merge_before_reductions():
    """Three chained atop layers must optimize to exactly two TOP layers.

    An elementwise ``inc`` over index ``'i'`` feeds a ``sum`` with output
    index ``''``, which in turn feeds a second ``sum`` with output index
    ``''``.  After ``optimize_atop`` on the last graph, two TOP layers are
    expected to remain.
    """
    source = da.ones(10, chunks=(5,))
    elementwise = da.atop(inc, 'i', source, 'i', dtype=source.dtype)
    first_sum = da.atop(sum, '', elementwise, 'i', dtype=elementwise.dtype)
    second_sum = da.atop(sum, '', first_sum, '', dtype=elementwise.dtype)

    graph = optimize_atop(second_sum.dask)
    n_top = len([lyr for lyr in graph.dicts.values() if isinstance(lyr, TOP)])
    assert n_top == 2

    # Computing the intermediate reduction should still succeed.
    first_sum.compute()