def test_inline_functions_protects_output_keys():
    """Keys listed as outputs must survive inlining of fast functions."""
    graph = {'x': (inc, 1), 'y': (double, 'x')}

    # With no outputs requested, the fast task 'x' is folded into 'y'.
    unprotected = inline_functions(graph, [], [inc])
    assert unprotected == {'y': (double, (inc, 1))}

    # Declaring 'x' as an output leaves the graph as it was.
    protected = inline_functions(graph, ['x'], [inc])
    assert protected == {'y': (double, 'x'), 'x': (inc, 1)}
def test_inline_functions_protects_output_keys():
    """A fast task is inlined unless its key is a requested output."""
    graph = {"x": (inc, 1), "y": (double, "x")}
    cases = [
        # (output keys, expected graph)
        ([], {"y": (double, (inc, 1))}),
        (["x"], {"y": (double, "x"), "x": (inc, 1)}),
    ]
    for output_keys, expected in cases:
        assert inline_functions(graph, output_keys, [inc]) == expected
def _flatten_output_keys(keys):
    """Yield each task key from a possibly nested list of output keys."""
    for key in keys:
        if isinstance(key, list):
            yield from _flatten_output_keys(key)
        else:
            yield key


def custom_delay_optimize(
    dsk: dict, keys: list, fast_functions=None, inline_patterns=None, **kwargs
) -> dict:
    """
    Custom optimization functions for delayed tasks.

    By default only fusing of tasks will be carried out.

    Parameters
    ----------
    dsk : dict
        Input dask task graph.
    keys : list
        Output task keys (may be nested lists). These keys are protected:
        they are neither fused into their dependents nor inlined away by
        the fast-function pass, so they remain in the returned graph.
    fast_functions : list, optional
        List of fast functions to be inlined. By default `None`, which
        skips function inlining.
    inline_patterns : list, optional
        List of patterns of task keys to be inlined. By default `None`,
        which skips pattern inlining.
    **kwargs
        Accepted so the function can be called with extra optimizer
        options; they are not used.

    Returns
    -------
    dsk : dict
        Optimized dask graph.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]
    # `inline_functions` builds a set from its output keys, so nested key
    # lists have to be flattened before they are handed over.
    output_keys = list(_flatten_output_keys(keys))

    # Passing the output keys keeps `fuse` from merging a requested key
    # into the single task that depends on it.
    dsk, _ = fuse(
        ensure_dict(dsk), output_keys, rename_keys=custom_fused_keys_renamer
    )
    if inline_patterns:
        dsk = inline_pattern(dsk, inline_patterns, inline_constants=False)
    if fast_functions:
        # Output keys must be listed here as well; otherwise a requested
        # key computed by a fast function would be inlined and dropped.
        dsk = inline_functions(
            dsk,
            output_keys,
            fast_functions=fast_functions,
        )
    return dsk
def test_inline_functions():
    """Only tasks built from fast functions are inlined into dependents."""
    x, y, i, d = "x", "y", "i", "d"
    graph = {
        "out": (add, i, d),
        i: (inc, x),
        d: (double, y),
        x: 1,
        y: 1,
    }
    # `inc` is fast, so task `i` is folded into "out"; `double` is not,
    # so task `d` stays a separate key.
    expected = {
        "out": (add, (inc, x), d),
        d: (double, y),
        x: 1,
        y: 1,
    }

    assert inline_functions(graph, [], fast_functions={inc}) == expected
def test_inline_ignores_curries_and_partials():
    """A ``partial`` around a fast function is still treated as fast."""
    partial_task = (partial(add, 1), 'x')
    graph = {'x': 1, 'y': 2, 'a': partial_task, 'b': (inc, 'a')}

    inlined = inline_functions(graph, [], fast_functions={add})

    # Task 'a' is gone as a key and its definition now sits inside 'b'.
    assert 'a' not in inlined
    assert inlined['b'] == (inc, graph['a'])
def optimize(
    dsk,
    keys,
    fuse_keys=None,
    fast_functions=None,
    inline_functions_fast_functions=(getter_inline,),
    rename_fused_keys=True,
    **kwargs,
):
    """Optimize a dask graph for array computation.

    The graph passes through these stages, in order:

    1. Blockwise optimization, root fusion and culling down to ``keys``,
       all on the high-level graph.
    2. Low-level task fusion, keeping the keys returned by ``hold_keys``
       together with ``keys`` and ``fuse_keys``. If the
       ``optimization.fuse.active`` config value is explicitly ``False``,
       the culled high-level graph is returned right after stage 1.
    3. Inlining of fast functions, when any are configured.
    4. Slice optimization via ``optimize_slices``.

    ``fast_functions``, when not ``None``, takes precedence over
    ``inline_functions_fast_functions``.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]
    keys = list(flatten(keys))

    if not isinstance(dsk, HighLevelGraph):
        dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

    dsk = optimize_blockwise(dsk, keys=keys)
    dsk = fuse_roots(dsk, keys=keys)
    dsk = dsk.cull(set(keys))

    # Only an explicit False disables low-level fusion; an unset (None)
    # value still runs it.
    if config.get("optimization.fuse.active") is False:
        return dsk

    # Dependencies are read from the high-level graph before it is
    # materialized into a plain dict.
    dependencies = dsk.get_all_dependencies()
    dsk = ensure_dict(dsk)

    # Low level task optimizations
    if fast_functions is None:
        functions_to_inline = inline_functions_fast_functions
    else:
        functions_to_inline = fast_functions

    protected_keys = hold_keys(dsk, dependencies) + keys + (fuse_keys or [])
    dsk, dependencies = fuse(
        dsk,
        protected_keys,
        dependencies,
        rename_keys=rename_fused_keys,
    )

    if functions_to_inline:
        dsk = inline_functions(
            dsk,
            keys,
            dependencies=dependencies,
            fast_functions=functions_to_inline,
        )

    return optimize_slices(dsk)
def test_inline_traverses_lists():
    """Keys nested inside a list argument are also replaced when inlining."""
    x, y, i, d = 'x', 'y', 'i', 'd'
    graph = {
        'out': (sum, [i, d]),
        i: (inc, x),
        d: (double, y),
        x: 1,
        y: 1,
    }

    inlined = inline_functions(graph, [], fast_functions={inc})

    # The fast task `i` is substituted inside the list; `d` is untouched.
    assert inlined == {
        'out': (sum, [(inc, x), d]),
        d: (double, y),
        x: 1,
        y: 1,
    }
def test_inline_functions():
    """Fast-function tasks are inlined; other tasks keep their own key."""
    x, y, i, d = 'x', 'y', 'i', 'd'
    graph = {'out': (add, i, d), i: (inc, x), d: (double, y), x: 1, y: 1}

    inlined = inline_functions(graph, [], fast_functions={inc})

    # `i` (built from the fast `inc`) is folded into 'out' and removed.
    assert inlined == {
        'out': (add, (inc, x), d),
        d: (double, y),
        x: 1,
        y: 1,
    }
def test_inline_functions_non_hashable():
    """Inlining works when a task's callable cannot be hashed."""

    class NonHashableCallable(object):
        def __hash__(self):
            raise TypeError("Not hashable")

        def __call__(self, value):
            return value + 1

    nohash = NonHashableCallable()
    graph = {
        'a': 1,
        'b': (inc, 'a'),
        'c': (nohash, 'b'),
        'd': (inc, 'c'),
    }

    inlined = inline_functions(graph, [], fast_functions={inc})

    # The fast task 'b' is folded into the unhashable task 'c'.
    assert 'b' not in inlined
    assert inlined['c'] == (nohash, graph['b'])
def test_inline_functions_non_hashable():
    """An unhashable callable in the graph must not break inlining."""

    class NonHashableCallable:
        def __hash__(self):
            raise TypeError("Not hashable")

        def __call__(self, value):
            return value + 1

    nohash = NonHashableCallable()
    fast_task = (inc, "a")
    graph = {"a": 1, "b": fast_task, "c": (nohash, "b"), "d": (inc, "c")}

    inlined = inline_functions(graph, [], fast_functions={inc})

    # 'b' disappears as a key and its task is embedded in 'c'.
    assert inlined["c"] == (nohash, graph["b"])
    assert "b" not in inlined
def test_inline_doesnt_shrink_fast_functions_at_top():
    """A fast task that nothing depends on is left in the graph."""
    graph = {"x": (inc, "y"), "y": 1}

    inlined = inline_functions(graph, [], fast_functions={inc})

    # 'x' has no dependents to be inlined into, so nothing changes.
    assert inlined == graph
def test_inline_ignores_curries_and_partials():
    """Tasks calling ``partial(add, ...)`` count as fast when ``add`` is."""
    graph = {
        "x": 1,
        "y": 2,
        "a": (partial(add, 1), "x"),
        "b": (inc, "a"),
    }

    inlined = inline_functions(graph, [], fast_functions={add})

    # The partial task 'a' is embedded in 'b' and dropped as its own key.
    assert inlined["b"] == (inc, graph["a"])
    assert "a" not in inlined
def time_inline_functions(self):
    """Run ``inline_functions`` once on the graph, keys and dependencies stored on ``self``."""
    graph, output_keys = self.dsk, self.keys
    inline_functions(
        graph,
        output_keys,
        fast_functions=[inc],
        dependencies=self.deps,
    )
def test_inline_functions_protects_output_keys():
    """Requested output keys are never inlined away."""
    graph = {'x': (inc, 1), 'y': (double, 'x')}
    fast = [inc]

    # No outputs requested: 'x' is merged into 'y'.
    merged = inline_functions(graph, [], fast)
    assert merged == {'y': (double, (inc, 1))}

    # 'x' requested as an output: both tasks are kept as they were.
    kept = inline_functions(graph, ['x'], fast)
    assert kept == {'y': (double, 'x'), 'x': (inc, 1)}
def test_inline_doesnt_shrink_fast_functions_at_top():
    """A top-level fast task (no dependents) must stay in the graph."""
    graph = {'x': (inc, 'y'), 'y': 1}

    # The graph is expected back unchanged.
    assert inline_functions(graph, [], fast_functions={inc}) == graph
def test_inline_traverses_lists():
    """Inlining reaches keys that appear inside list arguments."""
    x, y, i, d = "x", "y", "i", "d"
    graph = {"out": (sum, [i, d]), i: (inc, x), d: (double, y), x: 1, y: 1}

    inlined = inline_functions(graph, [], fast_functions={inc})

    # `i` is replaced by its task inside the list; `d` stays a reference.
    expected_out = (sum, [(inc, x), d])
    assert inlined == {"out": expected_out, d: (double, y), x: 1, y: 1}