def test_bad_bprop_def():
    """Registering a bprop for the made-up 'nonsense' primitive must fail.

    Decorating a bprop function with ``register_bprop`` for
    ``Primitive('nonsense')`` is expected to raise
    ``InternalInferenceError``.
    """
    from myia.prim import Primitive
    from myia.prim.grad_implementations import register_bprop
    from myia.utils import InternalInferenceError

    expected_error = InternalInferenceError
    with pytest.raises(expected_error):
        # Both the decorator call and its application to the function
        # happen inside the ``raises`` context, so either may raise.
        # The inner function is left exactly as-is because
        # ``register_bprop`` may inspect it.
        @register_bprop(Primitive('nonsense'))
        def _bprop_nonsense(x, y, dout):
            return dout + x + y
def _test_op(cls):
    """Create a test Primitive from ``cls`` and register its parts.

    The primitive is named after ``cls``. ``cls.impl`` is stored in
    ``pyimpl_test`` as the primitive's Python implementation. Every
    attribute of ``cls`` named ``infer_<track>`` is registered through
    ``register_inferrer`` for that track. The recognized tracks are
    ``type``, ``value`` and ``shape``.

    Arguments:
        cls: A class providing an ``impl`` function and zero or more
            ``infer_<track>`` methods.

    Returns:
        The newly created Primitive.

    Raises:
        ValueError: If an ``infer_*`` attribute names an unknown track.
            ``ValueError`` subclasses ``Exception``, so callers that caught
            the previous generic ``Exception`` still work.
    """
    import inspect

    # Loop-invariant prefix identifying inferrer methods on ``cls``.
    prefix = 'infer_'

    op = Primitive(cls.__name__)
    # Inferrers are registered with the same arity as the implementation.
    nargs = len(inspect.getfullargspec(cls.impl).args)
    pyimpl_test[op] = cls.impl

    for method in dir(cls):
        if not method.startswith(prefix):
            continue
        track = method[len(prefix):]
        # Look up each constructor table only when needed, so a missing
        # table only matters if some class actually uses that track.
        if track == 'type':
            cons = type_inferrer_cons_test
        elif track == 'value':
            cons = value_inferrer_cons_test
        elif track == 'shape':
            cons = shape_inferrer_cons_test
        else:
            raise ValueError(f'Unknown track to infer: {track}')
        inffn = getattr(cls, method)
        register_inferrer(op, nargs=nargs, constructors=cons)(inffn)
    return op
def _test_op(fn):
    """Create a test Primitive named after ``fn`` and register its inferrer.

    The inferrer is ``UniformPrimitiveInferrer.partial(impl=fn)``, stored in
    ``abstract_inferrer_cons_test`` under the new primitive.

    Arguments:
        fn: The Python implementation backing the primitive.

    Returns:
        The newly created Primitive.
    """
    op = Primitive(fn.__name__)
    abstract_inferrer_cons_test[op] = UniformPrimitiveInferrer.partial(
        impl=fn
    )
    return op
}) \ .select('resources', 'parse', 'resolve') \ .make_transformer('input', 'graph') specialize = scalar_pipeline \ .configure({ 'resources.convert.object_map': Merge({ operations.getitem: prim.tuple_getitem }) }) \ .select('resources', 'parse', 'resolve', 'infer', 'specialize') # We will optimize patterns of these fake primitives P = Primitive('P') Q = Primitive('Q') R = Primitive('R') idempotent_P = psub((P, (P, X)), (P, X), name='idempotent_P') elim_R = psub((R, X), X, name='elim_R') Q0_to_R = psub((Q, 0), (R, 0), name='Q0_to_R') QP_to_QR = psub((Q, (P, X)), (Q, (R, X)), name='QP_to_QR') multiply_by_zero_l = psub((prim.scalar_mul, 0, X), 0, name='multiply_by_zero_l')
import pytest from collections import Counter from myia.pipeline import scalar_parse as parse from myia.debug.label import short_labeler from myia.ir import manage, GraphManager, GraphCloner, ManagerError from myia.prim import Primitive from myia.utils import OrderedSet swap1 = Primitive('swap1') swap2 = Primitive('swap2') swap3 = Primitive('swap3') swap4 = Primitive('swap4') swap5 = Primitive('swap5') swap = swap1 swaps = [swap1, swap2, swap3, swap4, swap5] class NestingSpecs: def __init__(self, stage, specs): self.expected = self._parse_specs(specs) self.stage = stage def _parse_specs(self, specs): if specs is None: return None expected = {} for spec in specs.split(';'): spec = spec.strip()