コード例 #1
0
ファイル: gprint.py プロジェクト: fosterrath-mila/myia
def cosmetic_transformer(g):
    """Rewrite a graph into a form that is more pleasant to display.

    Every cosmetic rule listed below is registered in a new ``NodeMap``
    under the rule's ``interest`` attribute (``None`` when the rule does
    not define one), and a ``LocalPassOptimizer`` built from that map is
    then applied to ``g``.

    The result is meant for display only: it may contain nodes with fake
    functions that exist purely for cosmetic reasons, so it is not a valid
    graph to run.

    Args:
        g: The graph to transform.

    Returns:
        The same ``g`` object that was passed in, after the optimizer has
        been applied to it.
    """
    cosmetic_rules = (
        _opt_distributed_constant,
        _opt_fancy_make_tuple,
        _opt_fancy_getitem,
        _opt_fancy_resolve,
        _opt_fancy_record_getitem,
        _opt_fancy_array_map,
        _opt_fancy_distribute,
        _opt_fancy_transpose,
        _opt_fancy_sum,
        _opt_fancy_unsafe_static_cast,
        _opt_fancy_scalar_to_array,
        _opt_fancy_array_to_scalar,
        _opt_fancy_hastag,
        _opt_fancy_casttag,
        _opt_fancy_tagged,
        # careful=True
        # NOTE(review): disabled option kept as-is; its purpose is not
        # visible from this function.
    )

    node_map = NodeMap()
    for rule in cosmetic_rules:
        # Rules without an ``interest`` attribute are registered under None.
        node_map.register(getattr(rule, 'interest', None), rule)

    apply_rules = LocalPassOptimizer(node_map)
    apply_rules(g)
    return g
コード例 #2
0
def test_type_tracking():
    """Run three small functions through a pipeline with custom optimizations.

    The pipeline parses, infers, specializes and simplifies types, applies a
    ``LocalPassOptimizer`` built from ``opt_ok1``, ``opt_ok2`` and
    ``opt_err1``, and finally validates. ``fn_ok1`` and ``fn_ok2`` must go
    through without raising; ``fn_err1`` must raise ``ValidationError``.
    """

    def fn_ok1(x, y):
        return x + y

    def fn_ok2(x):
        return -x

    def fn_err1(x, y):
        return x - y

    def i64_argspec(count):
        # Build a new abstract i64 for every argument slot.
        return tuple(to_abstract_test(i64) for _ in range(count))

    pipeline = scalar_pipeline.with_steps(
        steps.step_parse,
        steps.step_infer,
        steps.step_specialize,
        steps.step_simplify_types,
        LocalPassOptimizer(opt_ok1, opt_ok2, opt_err1),
        steps.step_validate,
    )

    # These two are expected to pass validation.
    pipeline(input=fn_ok1, argspec=i64_argspec(2))
    pipeline(input=fn_ok2, argspec=i64_argspec(1))

    # This one is expected to be rejected by the pipeline.
    with pytest.raises(ValidationError):
        pipeline(input=fn_err1, argspec=i64_argspec(2))
コード例 #3
0
def test_type_tracking_2():
    """Check that ``x - y + x`` fails with ``InferenceError`` in the pipeline.

    The pipeline is the parse / infer / specialize / simplify-types sequence
    followed by a ``LocalPassOptimizer`` built from ``opt_ok1``, ``opt_ok2``
    and ``opt_err1``, and a final validation step.
    """

    def fn_err3(x, y):
        return x - y + x

    pipeline = scalar_pipeline.with_steps(
        steps.step_parse,
        steps.step_infer,
        steps.step_specialize,
        steps.step_simplify_types,
        LocalPassOptimizer(opt_ok1, opt_ok2, opt_err1),
        steps.step_validate,
    )

    with pytest.raises(InferenceError):
        # Both arguments are abstract i64 values, built inside the guarded
        # block just like the pipeline call itself.
        i64_pair = (to_abstract_test(i64), to_abstract_test(i64))
        pipeline(input=fn_err3, argspec=i64_pair)
コード例 #4
0
def _check_opt(before, after, *opts, argspec=None, argspec_after=None):
    """Check a before/after pair using an optimizer built from ``opts``.

    Each entry of ``opts`` is registered in a new ``NodeMap`` under its
    ``interest`` attribute (``None`` when it has none). A
    ``LocalPassOptimizer`` wrapping that map is then handed to
    ``_check_transform`` together with ``before``, ``after`` and the two
    argspecs.
    """
    node_map = NodeMap()
    for rule in opts:
        node_map.register(getattr(rule, 'interest', None), rule)

    optimizer = LocalPassOptimizer(node_map)
    _check_transform(before, after, optimizer, argspec, argspec_after)
コード例 #5
0
def _check_opt(before, after, *opts, argspec=None, argspec_after=None):
    """Check a before/after pair using an optimizer built from ``opts``.

    The optimizations are passed positionally to ``LocalPassOptimizer``;
    the resulting optimizer is handed to ``_check_transform`` together with
    ``before``, ``after`` and the two argspecs.
    """
    optimizer = LocalPassOptimizer(*opts)
    _check_transform(before, after, optimizer, argspec, argspec_after)