Esempio n. 1
0
def jax_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="jax_funcified_fgraph",
    **kwargs,
):
    """Convert `fgraph` by delegating to `fgraph_to_python` with the JAX hooks.

    Parameters
    ----------
    fgraph
        The graph object handed to `fgraph_to_python` as its first argument.
    node
        Accepted in the signature but not used by this function.
    fgraph_name
        Forwarded to `fgraph_to_python` under the same keyword.
    **kwargs
        Forwarded unchanged to `fgraph_to_python`.

    Returns
    -------
    Whatever `fgraph_to_python` returns for these arguments.
    """
    # `jax_funcify` is passed as the second positional argument and
    # `jax_typify` as the type conversion hook.
    converted = fgraph_to_python(
        fgraph,
        jax_funcify,
        type_conversion_fn=jax_typify,
        fgraph_name=fgraph_name,
        **kwargs,
    )
    return converted
Esempio n. 2
0
def test_fgraph_to_python_once():
    """Check that each converted `Op` runs exactly once per call of the generated function.

    Both outputs of `op1` feed two separate additions, so a call counter
    above one per invocation would reveal that a node was recomputed.
    """

    class TestOp(Op):
        """Pass-through `Op` whose Python conversion counts its invocations."""

        def __init__(self):
            self.called = 0

        def make_node(self, *args):
            # One fresh output variable per input, each of the input's own type.
            out_vars = [arg.type() for arg in args]
            return Apply(self, list(args), out_vars)

        def perform(self, inputs, outputs):
            for idx, value in enumerate(inputs):
                outputs[idx][0] = value[0]

    @to_python.register(TestOp)
    def to_python_TestOp(op, **kwargs):
        # The generated callable bumps the counter on the `Op` instance it
        # was created for, then returns its arguments as a list.
        def func(*args, op=op):
            op.called += 1
            return list(args)

        return func

    x = vector("x")
    y = vector("y")

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    # Two distinct `q + r` expressions on purpose: each one references the
    # outputs of `op1` again.
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_py = fgraph_to_python(out_fg, to_python)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    # After the n-th invocation every op must have been called exactly n times.
    for expected_calls in (1, 2):
        res = out_py(x_val, y_val)
        assert len(res) == 2
        assert op1.called == expected_calls
        assert op2.called == expected_calls
Esempio n. 3
0
def test_fgraph_to_python_multiline_str():
    """Make sure that multiline `__str__` values are supported by `fgraph_to_python`.

    An `Op` whose `__str__` contains a newline is placed in a graph, the graph
    is converted, and the source of the generated function is checked for a
    comment block in which every line of the multiline name is prefixed
    with ``#``.
    """
    # Imported locally, like the sibling names test, so this test does not
    # rely on a module-level import of `inspect`.
    import inspect

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        """Pass-through `Op` whose string form spans two lines."""

        def make_node(self, *args):
            # One fresh output variable per input, each of the input's own type.
            return Apply(self, list(args), [arg.type() for arg in args])

        def perform(self, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

        def __str__(self):
            return "Test\nOp()"

    @to_python.register(TestOp)
    def to_python_TestOp(op, **kwargs):
        def func(*args, op=op):
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_py = fgraph_to_python(out_fg, to_python)

    out_py_src = inspect.getsource(out_py)

    # The leading whitespace inside this literal is significant: it has to
    # match the text of the generated source exactly.
    expected_comment = """
    # Elemwise{add,no_inplace}(Test
    # Op().0, Test
    # Op().1)
    """
    assert expected_comment in out_py_src
Esempio n. 4
0
def test_fgraph_to_python_names():
    """Check the parameter names `fgraph_to_python` gives the generated function.

    The inputs carry names that are not usable as Python parameter names as-is
    (``"1x"``, ``"def"``), the name ``"_"``, no name at all, and `NoneConst`;
    the signature of the generated function is compared with the expected names.
    """
    import inspect

    # Creation order is kept as-is, since `auto_name` values may depend on it.
    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")
    r = NoneConst

    graph_vars = [x, y, z, q, r]
    out_fg = FunctionGraph(graph_vars, list(graph_vars), clone=False)
    out_py = fgraph_to_python(out_fg, to_python)

    expected_params = (x.auto_name, "_", z.auto_name, q.auto_name, r.name)
    sig = inspect.signature(out_py)
    assert tuple(sig.parameters) == expected_params

    # Inputs are also the outputs, so the call must hand its arguments back.
    call_args = (1, 2, 3, 4, 5)
    assert out_py(*call_args) == call_args

    # A plain object gets its type's name.
    anonymous = object()
    assert get_name_for_object(anonymous) == type(anonymous).__name__