def cosmetic_transformer(g):
    """Rewrite graph ``g`` so that it is easier to read when displayed.

    The result is NOT a runnable graph: the passes applied here may
    introduce nodes that call placeholder functions whose only purpose is
    cosmetic.

    Arguments:
        g: The graph to beautify. The optimizer is called on ``g`` directly
            and its return value is discarded, so changes presumably happen
            in place.

    Returns:
        The same object that was passed in as ``g``.
    """
    passes = (
        _opt_distributed_constant,
        _opt_fancy_make_tuple,
        _opt_fancy_getitem,
        _opt_fancy_resolve,
        _opt_fancy_record_getitem,
        _opt_fancy_array_map,
        _opt_fancy_distribute,
        _opt_fancy_transpose,
        _opt_fancy_sum,
        _opt_fancy_unsafe_static_cast,
        _opt_fancy_scalar_to_array,
        _opt_fancy_array_to_scalar,
        _opt_fancy_hastag,
        _opt_fancy_casttag,
        _opt_fancy_tagged,
        # careful=True
    )

    node_map = NodeMap()
    for rewrite in passes:
        # A pass is keyed on its ``interest`` attribute when it has one;
        # passes without that attribute are registered under ``None``.
        node_map.register(getattr(rewrite, 'interest', None), rewrite)

    LocalPassOptimizer(node_map)(g)
    return g
def test_type_tracking():
    """Run the scalar pipeline with ``opt_ok1``, ``opt_ok2`` and ``opt_err1``.

    Addition (``fn_ok1``) and negation (``fn_ok2``) are expected to pass
    every step, including ``step_validate``. Subtraction (``fn_err1``) is
    expected to fail with ``ValidationError``. Presumably ``opt_err1``
    rewrites subtraction into something ill-typed — confirm against its
    definition.
    """
    pip = scalar_pipeline.with_steps(
        steps.step_parse,
        steps.step_infer,
        steps.step_specialize,
        steps.step_simplify_types,
        LocalPassOptimizer(opt_ok1, opt_ok2, opt_err1),
        steps.step_validate,
    )

    def run(fn, nargs):
        # Build one abstract i64 per positional argument of ``fn``.
        signature = tuple(to_abstract_test(i64) for _ in range(nargs))
        return pip(input=fn, argspec=signature)

    def fn_ok1(x, y):
        return x + y

    def fn_ok2(x):
        return -x

    def fn_err1(x, y):
        return x - y

    run(fn_ok1, 2)
    run(fn_ok2, 1)

    with pytest.raises(ValidationError):
        run(fn_err1, 2)
def test_type_tracking_2():
    """Check that ``x - y + x`` fails in the pass-instrumented pipeline.

    This uses the same pipeline and passes as ``test_type_tracking``. Here
    the failure is expected to surface as ``InferenceError``, not as
    ``ValidationError``. Why the error kind differs is not visible from
    this test — confirm against the definitions of the ``opt_*`` passes.
    """
    rewriter = LocalPassOptimizer(opt_ok1, opt_ok2, opt_err1)
    pip = scalar_pipeline.with_steps(
        steps.step_parse,
        steps.step_infer,
        steps.step_specialize,
        steps.step_simplify_types,
        rewriter,
        steps.step_validate,
    )

    def fn_err3(x, y):
        return x - y + x

    with pytest.raises(InferenceError):
        pip(
            input=fn_err3,
            argspec=(to_abstract_test(i64), to_abstract_test(i64)),
        )
def _check_opt(before, after, *opts, argspec=None, argspec_after=None):
    """Check ``before`` against ``after`` using a NodeMap-based optimizer.

    Each pass in ``opts`` is registered in a fresh ``NodeMap``, keyed on
    its ``interest`` attribute (or ``None`` if it has none). The resulting
    ``LocalPassOptimizer`` is handed to ``_check_transform`` (defined
    elsewhere), which presumably applies it to ``before`` and compares the
    result with ``after``.

    Arguments:
        before: The function/graph to optimize.
        after: The expected result after optimization.
        *opts: The local optimization passes to apply.
        argspec: Forwarded unchanged to ``_check_transform``.
        argspec_after: Forwarded unchanged to ``_check_transform``.
    """
    # NOTE(review): this file defines ``_check_opt`` a second time further
    # down, and that later ``def`` rebinds the name. This NodeMap-based
    # version therefore appears to be dead code. Confirm which
    # LocalPassOptimizer signature is current and remove the other one.
    registry = NodeMap()
    for rewrite in opts:
        key = getattr(rewrite, 'interest', None)
        registry.register(key, rewrite)
    optimizer = LocalPassOptimizer(registry)
    _check_transform(before, after, optimizer, argspec, argspec_after)
def _check_opt(before, after, *opts, argspec=None, argspec_after=None):
    """Check ``before`` against ``after`` using the passes given positionally.

    The passes in ``opts`` are passed straight to ``LocalPassOptimizer``.
    The optimizer is handed to ``_check_transform`` (defined elsewhere),
    which presumably applies it to ``before`` and compares the result with
    ``after``.

    Arguments:
        before: The function/graph to optimize.
        after: The expected result after optimization.
        *opts: The local optimization passes to apply.
        argspec: Forwarded unchanged to ``_check_transform``.
        argspec_after: Forwarded unchanged to ``_check_transform``.
    """
    # NOTE(review): this rebinds an earlier ``_check_opt`` in this file that
    # registers the passes in a NodeMap instead. Only one calling convention
    # for LocalPassOptimizer can be current — confirm and drop the other
    # definition.
    optimizer = LocalPassOptimizer(*opts)
    _check_transform(
        before,
        after,
        optimizer,
        argspec,
        argspec_after,
    )