Beispiel #1
0
def test_bad_bprop_def():
    """Registering a bprop for the ad-hoc primitive ``'nonsense'`` must fail.

    A primitive named ``'nonsense'`` is created on the spot, and decorating
    a function with ``register_bprop`` for it is expected to raise
    ``InternalInferenceError``. Relies on ``pytest`` being imported at
    module level.
    """
    # Imports are local to the test body; NOTE(review): presumably so that
    # an import-time failure only affects this test -- confirm intent.
    from myia.prim.grad_implementations import register_bprop
    from myia.prim import Primitive
    from myia.utils import InternalInferenceError

    with pytest.raises(InternalInferenceError):
        # Either building the decorator or applying it must raise; the test
        # never calls the decorated function itself. No comments are placed
        # inside the nested function because its source is handed to
        # register_bprop.
        @register_bprop(Primitive('nonsense'))
        def _bprop_nonsense(x, y, dout):
            return dout + x + y
Beispiel #2
0
def _test_op(cls):
    """Create a test primitive from the class *cls* and register it.

    ``cls.impl`` is stored as the primitive's Python implementation in
    ``pyimpl_test``. The number of named positional parameters of
    ``cls.impl`` is the arity passed to ``register_inferrer``. Every
    attribute of *cls* named ``infer_<track>`` is registered as an inferrer,
    where ``<track>`` must be ``type``, ``value`` or ``shape``.

    Returns:
        The newly created ``Primitive``, named after *cls*.

    Raises:
        Exception: if an ``infer_*`` attribute names an unknown track.
    """
    import inspect
    prefix = 'infer_'
    new_op = Primitive(cls.__name__)
    arity = len(inspect.getfullargspec(cls.impl).args)
    pyimpl_test[new_op] = cls.impl
    for attr_name in dir(cls):
        if not attr_name.startswith(prefix):
            continue
        track = attr_name[len(prefix):]
        # Map each supported track to its inferrer-constructor collection.
        dispatch = {
            'type': type_inferrer_cons_test,
            'value': value_inferrer_cons_test,
            'shape': shape_inferrer_cons_test,
        }
        if track not in dispatch:
            raise Exception(f'Unknown track to infer: {track}')
        constructors = dispatch[track]
        inferrer_fn = getattr(cls, attr_name)
        register_inferrer(new_op, nargs=arity,
                          constructors=constructors)(inferrer_fn)
    return new_op
Beispiel #3
0
def _test_op(fn):
    """Create a primitive named after *fn* and register a uniform inferrer.

    The inferrer is ``UniformPrimitiveInferrer.partial(impl=fn)``, stored in
    ``abstract_inferrer_cons_test`` under the new primitive.

    Returns:
        The newly created ``Primitive``.
    """
    new_prim = Primitive(fn.__name__)
    abstract_inferrer_cons_test[new_prim] = \
        UniformPrimitiveInferrer.partial(impl=fn)
    return new_prim
Beispiel #4
0
    }) \
    .select('resources', 'parse', 'resolve') \
    .make_transformer('input', 'graph')


# Scalar pipeline restricted to the resources/parse/resolve/infer/specialize
# steps; ``operations.getitem`` is additionally mapped to
# ``prim.tuple_getitem`` in the conversion object map.
specialize = (
    scalar_pipeline
    .configure({
        'resources.convert.object_map': Merge({
            operations.getitem: prim.tuple_getitem,
        }),
    })
    .select('resources', 'parse', 'resolve', 'infer', 'specialize')
)

# Fake primitives whose patterns the rewrite rules below optimize.
P, Q, R = map(Primitive, ('P', 'Q', 'R'))

# Rules are written psub(pattern, replacement, name=...); NOTE(review):
# argument roles inferred from usage in this file.

# P(P(X)) -> P(X)
idempotent_P = psub((P, (P, X)), (P, X), name='idempotent_P')

# R(X) -> X
elim_R = psub((R, X), X, name='elim_R')

# Q(0) -> R(0)
Q0_to_R = psub((Q, 0), (R, 0), name='Q0_to_R')

# Q(P(X)) -> Q(R(X))
QP_to_QR = psub((Q, (P, X)), (Q, (R, X)), name='QP_to_QR')

# scalar_mul(0, X) -> 0
multiply_by_zero_l = psub((prim.scalar_mul, 0, X), 0,
                          name='multiply_by_zero_l')
Beispiel #5
0
import pytest

from collections import Counter

from myia.pipeline import scalar_parse as parse
from myia.debug.label import short_labeler
from myia.ir import manage, GraphManager, GraphCloner, ManagerError
from myia.prim import Primitive
from myia.utils import OrderedSet

# Five distinct test primitives, also collected in ``swaps``;
# ``swap`` is an alias for ``swap1``.
swaps = [Primitive(name)
         for name in ('swap1', 'swap2', 'swap3', 'swap4', 'swap5')]
swap1, swap2, swap3, swap4, swap5 = swaps
swap = swap1


class NestingSpecs:
    def __init__(self, stage, specs):
        """Store *stage* and the parsed form of *specs*.

        Args:
            stage: Stored unchanged as ``self.stage``; its meaning is not
                visible here (NOTE(review): confirm against callers).
            specs: A ``';'``-separated specification string, or ``None``.
                It is parsed by ``_parse_specs`` into ``self.expected``,
                which is ``None`` when *specs* is ``None``.
        """
        # Parsing runs before self.stage is assigned.
        self.expected = self._parse_specs(specs)
        self.stage = stage

    def _parse_specs(self, specs):
        if specs is None:
            return None
        expected = {}
        for spec in specs.split(';'):
            spec = spec.strip()