Example #1
0
    def wrapped(shape, *args, **kwargs):
        """Build a dask array of ``shape`` from per-block calls to ``func``.

        ``shape`` may be an iterable of ints or a single int. The array is
        split into blocks of roughly 100 Mi elements each. ``func`` is called
        once per block with that block's shape (plus ``*args``/``**kwargs``)
        while this module's backend is skipped. It is called once more with
        an all-zeros shape to obtain ``meta`` and the dtype.

        Raises:
            ValueError: if any dimension of ``shape`` is negative.
        """
        if isinstance(shape, collections.abc.Iterable):
            shape = tuple(int(s) for s in shape)
        else:
            shape = (int(shape), )

        if any(dim < 0 for dim in shape):
            raise ValueError("negative dimensions are not allowed")

        ndim = len(shape)
        # Estimate 100 Mi elements per block: take the ndim-th root so the
        # product of the per-axis block lengths is about 100 * 2**20.
        # A 0-d array has exactly one (empty-shaped) block, so the root is
        # skipped; computing it would divide by zero.
        blocksize = int((100 * (2**20))**(1 / ndim)) if ndim else 1

        chunks = []
        for length in shape:
            axis_chunks = []
            remaining = length
            while remaining > 0:
                step = min(blocksize, remaining)
                axis_chunks.append(step)
                remaining -= step
            # dask describes a zero-length axis with a single empty chunk.
            # Without it the axis would contribute no blocks at all.
            if not axis_chunks:
                axis_chunks.append(0)
            chunks.append(axis_chunks)

        name = func.__name__ + "-" + hex(random.randrange(2**64))
        dsk = {}
        # Skip this module's own backend while calling ``func``. Presumably
        # this makes ``func`` dispatch to another backend instead of back into
        # this wrapper; confirm.
        # NOTE(review): ``func`` runs eagerly here, so every block is
        # materialised up front rather than lazily by dask.
        with skip_backend(sys.modules[__name__]):
            for chunk_id in itertools.product(*(range(len(c)) for c in chunks)):
                block_shape = tuple(
                    chunks[axis][idx] for axis, idx in enumerate(chunk_id))
                dsk[(name, ) + chunk_id] = func(block_shape, *args, **kwargs)

            # An all-zeros shape yields a cheap template (``meta``) that
            # carries the output's array type and dtype.
            meta = func((0, ) * ndim, *args, **kwargs)
            dtype = str(meta.dtype)

        return da.Array(dsk, name, chunks, dtype=dtype, meta=meta)
Example #2
0
def test_pickle_state():
    """A captured uarray state survives a pickle round trip unchanged."""
    ua.set_global_backend(ComparableBackend("a"))
    ua.register_backend(ComparableBackend("b"))

    active = ComparableBackend("c")
    skipped = ComparableBackend("d")
    with ua.set_backend(active), ua.skip_backend(skipped):
        state = ua.get_state()

    serialized = pickle.dumps(state)
    roundtripped = pickle.loads(serialized)

    expected = state._pickle()
    assert expected == roundtripped._pickle()
Example #3
0
def test_getset_state(cleanup_backends):
    """A state captured with ``get_state`` can be re-applied with ``set_state``."""
    ua.set_global_backend(Backend())
    ua.register_backend(Backend())

    with ua.set_backend(Backend()), ua.skip_backend(Backend()):
        captured = ua.get_state()

    captured_pickle = captured._pickle()

    # Once the set/skip contexts have exited, the current state must no
    # longer match the captured one.
    assert captured_pickle != ua.get_state()._pickle()

    with ua.set_state(captured):
        # Only the first two entries of the pickled form are compared.
        # Presumably these are the parts set_state restores; confirm against
        # State._pickle.
        restored_head = ua.get_state()._pickle()[:2]
        assert captured_pickle[:2] == restored_head
Example #4
0
def test_skip_comparison(nullary_mm):
    """``skip_backend`` matches backends using ``==``.

    ``Backend2`` claims equality with ``be1``, so skipping a ``Backend2``
    instance also skips ``be1``. Dispatch then has no backend left and must
    raise ``BackendNotImplementedError``.
    """
    be1 = Backend()
    be1.__ua_function__ = lambda f, a, kw: None

    class Backend2(Backend):
        @staticmethod
        def __ua_function__(f, a, kw):
            pass

        def __eq__(self, other):
            return other is be1 or other is self

    impostor = Backend2()
    with pytest.raises(ua.BackendNotImplementedError):
        with ua.set_backend(be1), ua.skip_backend(impostor):
            nullary_mm()
Example #5
0
def test_multidomain_backends():
    """Check a backend that spans several domains.

    The backend must be usable via ``set_backend``, ``set_global_backend`` and
    ``register_backend``. It must be suppressed in every domain by
    ``skip_backend``, and be removable one domain at a time with
    ``clear_backends``.
    """
    n_domains = 2
    domains = ["ua_tests" + str(i) for i in range(n_domains)]
    be = DisableBackend(domain=domains)

    mms = [
        ua.generate_multimethod(lambda: (), lambda a, kw, d: (a, kw), domain)
        for domain in domains
    ]

    def assert_no_backends():
        for mm in mms:
            with pytest.raises(ua.BackendNotImplementedError):
                mm()

    def assert_backend_active(backend):
        # One assert per multimethod, so a failure points at the domain.
        for mm in mms:
            assert mm() is backend.ret

    assert_no_backends()

    with ua.set_backend(be):
        assert_backend_active(be)

    ua.set_global_backend(be)
    assert_backend_active(be)

    with ua.skip_backend(be):
        assert_no_backends()

    assert_backend_active(be)

    # Clearing one domain must disable only that domain. The domains not yet
    # cleared still dispatch to ``be``.
    for i, mm in enumerate(mms):
        ua.clear_backends(mm.domain, globals=True)

        with pytest.raises(ua.BackendNotImplementedError):
            mm()

        for remaining in mms[i + 1:]:
            assert remaining() is be.ret

    assert_no_backends()

    # NOTE(review): ``be`` is left registered after this test; confirm a
    # fixture resets backend state between tests.
    ua.register_backend(be)
    assert_backend_active(be)
Example #6
0
def test_skip_raises(nullary_mm):
    be1 = Backend()
    be1.__ua_function__ = lambda f, a, kw: None

    foo = Exception("Foo")

    class Backend2(Backend):
        @staticmethod
        def __ua_function__(f, a, kw):
            pass

        def __eq__(self, other):
            raise foo

    with pytest.raises(Exception) as e:
        with ua.set_backend(be1), ua.skip_backend(Backend2()):
            nullary_mm()

    assert e.value is foo
Example #7
0
 def be2_ua_func(f, a, kw):
     """Call ``f(*a, **kw)`` with ``be_inner`` skipped during dispatch."""
     skip_inner = ua.skip_backend(be_inner)
     with skip_inner:
         return f(*a, **kw)
Example #8
0
 def wrapped(*args, **kwargs):
     """Apply ``func`` blockwise through ``da.map_blocks``.

     The ``map_blocks`` call is made while this module's backend is skipped.
     NOTE(review): ``map_blocks`` presumably builds a lazy graph. If so,
     ``func`` runs at compute time, after this skip context has exited.
     Confirm that the backend state ``func`` then sees is the intended one.
     """
     this_module = sys.modules[__name__]
     with skip_backend(this_module):
         return da.map_blocks(func, *args, **kwargs)
Example #9
0
 def wrapped(*args, **kwargs):
     """Apply ``func`` blockwise through ``da.map_blocks``.

     This module's backend is skipped while ``func`` is wrapped with
     ``wrap_current_state`` and handed to ``map_blocks``. The wrapping
     presumably captures the current backend state so the deferred per-block
     calls see it; confirm against ``wrap_current_state``.
     """
     this_module = sys.modules[__name__]
     with skip_backend(this_module):
         state_bound_func = wrap_current_state(func)
         return da.map_blocks(state_bound_func, *args, **kwargs)
Example #10
0
import sys
import uarray as ua

from ._uarray_plug import __ua_domain__, __ua_convert__, __ua_function__
from ._func_diff_registry import FunctionDifferentialRegistry, register_diff, diff
from ._diff_array import Variable, DiffArray, ArrayDiffRegistry

__all__ = [
    "__ua_domain__",
    "__ua_convert__",
    "__ua_function__",
    "FunctionDifferentialRegistry",
    "register_diff",
    "diff",
    "Variable",
    "DiffArray",
    "ArrayDiffRegistry",
]

# Context manager that skips this package's own uarray backend.
# It looks the package up via ``sys.modules[__name__]`` instead of a
# hard-coded "udiff". The module is already in ``sys.modules`` while its body
# executes, so this always refers to this package, even if it is imported
# under a different name (e.g. vendored).
# NOTE(review): this single object is shared by every ``with SKIP_SELF:``;
# confirm that skip_backend contexts are safe to re-enter.
SKIP_SELF = ua.skip_backend(sys.modules[__name__])