예제 #1
0
def test_expression_statements():
    """``parse`` accepts a function containing bare expression statements.

    Inside ``f``, both the docstring and the ``print(x)`` call are
    expression statements. The test passes as long as parsing does not
    raise.
    """
    # ``f`` is only handed to ``parse``, never called, hence the pragma.
    def f(x):  # pragma: no cover
        """Foo."""
        print(x)
        return x

    parse(f)
예제 #2
0
def test_maybe():
    """``parse`` handles a variable that is assigned only inside a loop.

    ``x`` is bound only in the body of ``while True``. At runtime the loop
    never exits, so ``return x`` is unreachable. NOTE(review): presumably
    the parser must treat ``x`` as "maybe" bound at the return — confirm.
    The test only checks that parsing does not raise.
    """
    # ``f`` is only handed to ``parse``, never called, hence the pragma.
    def f():  # pragma: no cover
        while True:
            x = 2
        return x

    parse(f)
예제 #3
0
def test_unsupported_object():
    """``parse`` raises ``ValueError`` for a captured plain ``object()``.

    The free variable ``c`` returned by ``f`` is bound to a bare
    ``object()`` instance.
    """
    c = object()

    def f():  # pragma: no cover
        return c

    # Functional form of pytest.raises: calls parse(f) and expects ValueError.
    pytest.raises(ValueError, parse, f)
예제 #4
0
def test_dfs_variants():
    """Compare the node sets reached by the DFS variants from a closure.

    ``f`` defines a closure ``g`` that uses ``f``'s local ``z``. Starting
    from ``g``'s return node, the traversals are expected to reach:

    * ``succ_deep``: ``g``'s own nodes plus the definition of ``z``
      (``mul x``);
    * ``succ_deeper``: additionally the rest of ``f`` (``w``, ``3``,
      ``q``, ``g``);
    * ``freevars_boundary(inner_graph, True)``: ``g``'s own nodes and the
      free variable ``z``, but nothing past it;
    * ``freevars_boundary(inner_graph, False)``: ``g``'s own nodes only;
    * ``exclude_from_set([inner_ret])``: nothing, since the root itself is
      excluded.
    """
    # ``f`` is only handed to ``parse``, never called. The pragma matches
    # the other parser fixtures in this file.
    def f(x):  # pragma: no cover
        z = x * x

        def g(y):
            return y + z

        w = z + 3
        q = g(w)
        return q

    graph = parse(f)
    # ``f`` contains exactly one nested function (``g``). The 1-tuple
    # unpack therefore fails loudly if DFS finds any other number of graph
    # constants.
    inner_graph_ct, = [x for x in dfs(graph.return_) if is_constant_graph(x)]
    inner_graph = inner_graph_ct.value

    inner_ret = inner_graph.return_

    deep = _name_nodes(_dfs(inner_ret, succ_deep))
    assert deep == set('. return add y z mul x'.split())

    deeper = _name_nodes(_dfs(inner_ret, succ_deeper))
    assert deeper == set('. return add y z mul x w 3 q g'.split())

    _bound_fv = freevars_boundary(inner_graph, True)
    bound_fv = _name_nodes(_dfs(inner_ret, succ_deeper, _bound_fv))
    assert bound_fv == set('. return add y z'.split())

    _no_fv = freevars_boundary(inner_graph, False)
    no_fv = _name_nodes(_dfs(inner_ret, succ_deeper, _no_fv))
    assert no_fv == set('. return add y'.split())

    _excl_root = exclude_from_set([inner_ret])
    excl_root = _name_nodes(_dfs(inner_ret, succ_deeper, _excl_root))
    assert excl_root == set()
예제 #5
0
        def test():
            """Check ``NestingAnalyzer`` results for the graphs of ``fn``.

            The test first verifies that the parent and children maps agree
            with ``nested_in`` in both directions. It then compares the
            total and direct free-variable maps against
            ``expected_fvs_total`` and ``expected_fvs_direct``. Before
            comparison, graphs and variables are converted to their names,
            ``fn``'s own graph is reported as ``'X'``, and graphs without
            free variables are omitted.
            """
            # Imported once per test rather than on every ``name`` call.
            from myia.anf_ir import Constant, Graph

            gfn = parse(fn)

            def name(g):
                """Return ``g``'s debug name, unwrapping graph constants.

                ``fn``'s own graph is reported as ``'X'``.
                """
                if isinstance(g, Constant) and isinstance(g.value, Graph):
                    g = g.value
                gname = g.debug.name
                if gname == fn.__name__:
                    gname = 'X'
                return gname

            def named_fvs(fv_map):
                """Convert an FV map to names, omitting empty entries."""
                return {name(g): {name(v) for v in vs}
                        for g, vs in fv_map.items() if vs}

            analysis = NestingAnalyzer(gfn)
            # Nesting must be strict: a parent contains its child, never
            # the other way round.
            for g1, g2 in analysis.parents().items():
                if g2:
                    assert analysis.nested_in(g1, g2)
                    assert not analysis.nested_in(g2, g1)

            for g1, children in analysis.children().items():
                for child in children:
                    assert analysis.nested_in(child, g1)
                    assert not analysis.nested_in(g1, child)

            assert (named_fvs(analysis.free_variables_total())
                    == expected_fvs_total)
            assert (named_fvs(analysis.free_variables_direct())
                    == expected_fvs_direct)
예제 #6
0
 def test(args):
     """Check that ``fn`` gives the same result in Python and in Myia.

     A non-tuple ``args`` is treated as a single positional argument. Each
     execution receives its own ``copy`` of every argument, presumably so
     that mutation by one run cannot affect the other.
     """
     call_args = args if isinstance(args, tuple) else (args,)
     # TODO: avoid re-parsing every time
     graph = parse(fn)
     expected = fn(*[copy(arg) for arg in call_args])
     actual = run(graph, tuple(copy(arg) for arg in call_args))
     assert expected == actual
예제 #7
0
파일: vm.py 프로젝트: jangocheng/myia
 def convert_value(self, value):
     """Return ``value`` in a form the VM can work with.

     Plain Python functions are run through ``parse`` and
     ``CallableClosure`` wrappers are unwrapped to their ``closure``. Any
     other value is returned unchanged.
     """
     if isinstance(value, FunctionType):
         # Imported lazily. NOTE(review): presumably to avoid a circular
         # import with ``myia.api`` -- confirm.
         from myia.api import parse as parse_function
         return parse_function(value)
     if isinstance(value, CallableClosure):
         return value.closure
     return value
예제 #8
0
def test_closure_recur():
    """Python and Myia agree on mutually recursive functions with a closure.

    ``fn`` defines a closure ``g`` and calls ``f``, which in turn calls
    ``fn`` again with a decremented ``x``.
    """
    # This cannot run with parse_compare since we need to reference the
    # top-level function

    def f(x, y):
        return fn(x - 1, y)

    def fn(x, y):
        def g(x):
            return x + 1
        if x == 0:
            return g(y)
        else:
            return f(x, g(y))

    graph = parse(fn)
    call_args = (1, 2)
    expected = fn(*call_args)
    actual = run(graph, call_args)
    assert expected == actual
예제 #9
0
def test_unsupported():
    """``parse`` raises ``NotImplementedError`` on an ``assert`` statement."""
    def f():  # pragma: no cover
        assert False

    # Functional form of pytest.raises: calls parse(f) and expects the error.
    pytest.raises(NotImplementedError, parse, f)
예제 #10
0
def test_undefined():
    """``parse`` raises ``ValueError`` when a function uses an unbound name."""
    def f():  # pragma: no cover
        return c  # noqa

    # Functional form of pytest.raises: calls parse(f) and expects ValueError.
    pytest.raises(ValueError, parse, f)