def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
  """Abstract evaluation shared by single-result primitives.

  The output kind is chosen from the least specialized input aval type, i.e.
  the type with the largest ``array_abstraction_level`` among ``avals``:

  * ``core.ConcreteArray``: ``prim.impl`` is run on each input's ``.val`` and
    the result is wrapped in a ``ConcreteArray`` using the result's dtype.
  * ``core.ShapedArray``: shape, dtype and named shape come from
    ``shape_rule``, ``dtype_rule`` and ``named_shape_rule``.
  * ``core.DShapedArray``: shape and dtype come from the rules; the result is
    a ``ShapedArray`` when every output dimension is a plain ``int`` and a
    ``DShapedArray`` otherwise (``named_shape_rule`` is not consulted).
  * ``core.UnshapedArray``: only ``dtype_rule`` is consulted.

  Args:
    prim: the primitive being evaluated; must not have multiple results.
    shape_rule, dtype_rule, weak_type_rule, named_shape_rule: callables
      invoked as ``rule(*avals, **kwargs)``.
    *avals: input abstract values; each must be a ``core.UnshapedArray``
      instance (subclasses included).
    **kwargs: primitive parameters forwarded to the rules and ``prim.impl``.

  Returns:
    A single abstract value whose weak type comes from ``weak_type_rule``.

  Raises:
    TypeError: if the least specialized aval type is none of the above.
  """
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  assert not prim.multiple_results
  weak_type = weak_type_rule(*avals, **kwargs)
  by_level = operator.attrgetter('array_abstraction_level')
  least_specialized = _max(map(type, avals), key=by_level)

  if least_specialized is core.ConcreteArray:
    concrete_inputs = [aval.val for aval in avals]
    result = prim.impl(*concrete_inputs, **kwargs)
    return core.ConcreteArray(result.dtype, result, weak_type=weak_type)

  if least_specialized is core.ShapedArray:
    # Rule call order (shape, dtype, named shape) matches the original.
    out_shape = shape_rule(*avals, **kwargs)
    out_dtype = dtype_rule(*avals, **kwargs)
    out_named_shape = named_shape_rule(*avals, **kwargs)
    return core.ShapedArray(out_shape, out_dtype, weak_type=weak_type,
                            named_shape=out_named_shape)

  if least_specialized is core.DShapedArray:
    out_shape = shape_rule(*avals, **kwargs)
    # Fully static shapes collapse back to an ordinary ShapedArray.
    if all(type(dim) is int for dim in out_shape):
      aval_cls = core.ShapedArray
    else:
      aval_cls = core.DShapedArray
    return aval_cls(out_shape, dtype_rule(*avals, **kwargs), weak_type)

  if least_specialized is core.UnshapedArray:
    out_dtype = dtype_rule(*avals, **kwargs)
    return core.UnshapedArray(out_dtype, weak_type=weak_type)

  raise TypeError(avals, least_specialized)
def standard_multi_result_abstract_eval(
    prim, shape_rule, dtype_rule, weak_type_rule, named_shape_rule, *avals,
    **kwargs):
  """Abstract evaluation shared by multiple-result primitives.

  Like ``standard_abstract_eval``, but every rule returns one entry per output
  and a list of abstract values is produced. The output kind is chosen from
  the least specialized input aval type (largest
  ``array_abstraction_level``):

  * ``core.ConcreteArray``: ``prim.impl`` is run on each input's ``.val``;
    each output value is wrapped in a ``ConcreteArray``.
  * ``core.ShapedArray``: per-output shapes, dtypes and named shapes come
    from the rules.
  * ``core.UnshapedArray``: only the per-output dtypes are used.

  Args:
    prim: the primitive being evaluated; must have multiple results.
    shape_rule, dtype_rule, weak_type_rule, named_shape_rule: callables
      invoked as ``rule(*avals, **kwargs)``, each returning a sequence with
      one entry per output.
    *avals: input abstract values; each must be a ``core.UnshapedArray``
      instance (subclasses included).
    **kwargs: primitive parameters forwarded to the rules and ``prim.impl``.

  Returns:
    A list of abstract values, one per output.

  Raises:
    TypeError: if the least specialized aval type is none of the above
      (``core.DShapedArray`` is not handled here).
  """
  assert prim.multiple_results
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  by_level = operator.attrgetter('array_abstraction_level')
  least_specialized = _max(map(type, avals), key=by_level)
  out_weak_types = weak_type_rule(*avals, **kwargs)

  if least_specialized is core.ConcreteArray:
    results = prim.impl(*[aval.val for aval in avals], **kwargs)
    # safe_zip checks that the output and weak-type counts agree.
    return [core.ConcreteArray(res.dtype, res, weak_type=wt)
            for res, wt in safe_zip(results, out_weak_types)]

  if least_specialized is core.ShapedArray:
    shapes = shape_rule(*avals, **kwargs)
    dtypes = dtype_rule(*avals, **kwargs)
    named_shapes = named_shape_rule(*avals, **kwargs)
    per_output = safe_zip(shapes, dtypes, out_weak_types, named_shapes)
    return [core.ShapedArray(shape, dtype, weak_type=wt, named_shape=nshape)
            for shape, dtype, wt, nshape in per_output]

  if least_specialized is core.UnshapedArray:
    dtypes = dtype_rule(*avals, **kwargs)
    return [core.UnshapedArray(dtype, weak_type=wt)
            for dtype, wt in safe_zip(dtypes, out_weak_types)]

  raise TypeError(avals, least_specialized)
def _threefry2x32_abstract_eval(*args):
  """Abstract evaluation rule for the threefry2x32 primitive.

  Every argument must have dtype uint32. The primitive has two outputs that
  share one abstract value:

  * if every argument is a ``core.ShapedArray`` (subclasses included), the
    output is a uint32 ``ShapedArray`` whose shape is the broadcast of the
    argument shapes and whose named shape joins theirs;
  * otherwise the output is a uint32 ``core.UnshapedArray``.

  Raises:
    TypeError: if any argument's dtype is not uint32.
  """
  if any(x.dtype != jnp.uint32 for x in args):
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))
  u32 = jnp.dtype(jnp.uint32)
  if not all(isinstance(x, core.ShapedArray) for x in args):
    out_aval = core.UnshapedArray(u32)
  else:
    out_shape = lax._broadcasting_shape_rule(*args)
    out_named_shape = core.join_named_shapes(*(x.named_shape for x in args))
    out_aval = core.ShapedArray(out_shape, u32, named_shape=out_named_shape)
  return (out_aval, out_aval)