示例#1
0
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
    """Compute the output abstract value of a single-result primitive.

    The kind of aval returned is chosen from the input aval type with the
    largest ``array_abstraction_level`` (the least specialized input):

    * ``core.ConcreteArray``: ``prim.impl`` is run on the inputs' ``.val``
      payloads and its result is wrapped in a ``ConcreteArray``.
    * ``core.ShapedArray``: shape, dtype and named shape come from the rules.
    * ``core.DShapedArray``: shape and dtype come from the rules. The result
      is a ``ShapedArray`` when every dimension is a plain Python ``int``
      (checked with ``type(d) is int``), and a ``DShapedArray`` otherwise.
      No named shape is passed in this case.
    * ``core.UnshapedArray``: only the dtype rule is consulted.

    Args:
      prim: the primitive being evaluated. It must not have multiple results.
      shape_rule, dtype_rule, weak_type_rule, named_shape_rule: callables
        invoked as ``rule(*avals, **kwargs)``.
      *avals: input abstract values. Each must be an instance of
        ``core.UnshapedArray`` or one of its subclasses.
      **kwargs: primitive parameters, forwarded to every rule and to
        ``prim.impl``.

    Returns:
      The output abstract value. Every branch carries the weak type computed
      by ``weak_type_rule``.

    Raises:
      TypeError: if the least specialized input type is none of the four
        handled aval classes.
    """
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
    assert not prim.multiple_results

    # The weak type is needed by every branch, so compute it once up front.
    weak_type = weak_type_rule(*avals, **kwargs)
    by_level = operator.attrgetter('array_abstraction_level')
    kind = _max(map(type, avals), key=by_level)

    if kind is core.ConcreteArray:
        # Every input carries a concrete value, so evaluate the primitive eagerly.
        result = prim.impl(*[aval.val for aval in avals], **kwargs)
        return core.ConcreteArray(result.dtype, result, weak_type=weak_type)

    if kind is core.ShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        out_dtype = dtype_rule(*avals, **kwargs)
        out_named_shape = named_shape_rule(*avals, **kwargs)
        return core.ShapedArray(out_shape, out_dtype, weak_type=weak_type,
                                named_shape=out_named_shape)

    if kind is core.DShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        # Demote to a plain ShapedArray when no dimension is anything but an int.
        only_int_dims = all(type(dim) is int for dim in out_shape)
        out_cls = core.ShapedArray if only_int_dims else core.DShapedArray
        return out_cls(out_shape, dtype_rule(*avals, **kwargs), weak_type)

    if kind is core.UnshapedArray:
        return core.UnshapedArray(dtype_rule(*avals, **kwargs),
                                  weak_type=weak_type)

    raise TypeError(avals, kind)
示例#2
0
def standard_multi_result_abstract_eval(prim, shape_rule, dtype_rule,
                                        weak_type_rule, named_shape_rule,
                                        *avals, **kwargs):
    """Compute the output abstract values of a multiple-result primitive.

    Each rule returns one entry per output. The entries are combined
    elementwise with ``safe_zip`` (presumably this enforces equal lengths;
    confirm against its definition). The kind of avals produced follows the
    input aval type with the largest ``array_abstraction_level``:

    * ``core.ConcreteArray``: ``prim.impl`` is run on the inputs' ``.val``
      payloads, and each output value is wrapped in a ``ConcreteArray``.
    * ``core.ShapedArray``: shapes, dtypes and named shapes come from the rules.
    * ``core.UnshapedArray``: only the dtype rule is consulted.

    Any other least specialized type, including ``core.DShapedArray``
    (which has no branch here), raises ``TypeError``.

    Args:
      prim: the primitive being evaluated. It must have multiple results.
      shape_rule, dtype_rule, weak_type_rule, named_shape_rule: callables
        invoked as ``rule(*avals, **kwargs)``.
      *avals: input abstract values. Each must be an instance of
        ``core.UnshapedArray`` or one of its subclasses.
      **kwargs: primitive parameters, forwarded to every rule and to
        ``prim.impl``.

    Returns:
      A list with one abstract value per output.
    """
    assert prim.multiple_results
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals

    kind = _max(map(type, avals),
                key=operator.attrgetter('array_abstraction_level'))
    # One weak-type flag per output, shared by every branch below.
    weak_types = weak_type_rule(*avals, **kwargs)

    if kind is core.ConcreteArray:
        results = prim.impl(*[aval.val for aval in avals], **kwargs)
        pairs = safe_zip(results, weak_types)
        return [core.ConcreteArray(res.dtype, res, weak_type=wt)
                for res, wt in pairs]

    if kind is core.ShapedArray:
        shapes = shape_rule(*avals, **kwargs)
        dtypes = dtype_rule(*avals, **kwargs)
        named_shapes = named_shape_rule(*avals, **kwargs)
        rows = safe_zip(shapes, dtypes, weak_types, named_shapes)
        return [core.ShapedArray(shape, dtype, weak_type=wt, named_shape=ns)
                for shape, dtype, wt, ns in rows]

    if kind is core.UnshapedArray:
        dtypes = dtype_rule(*avals, **kwargs)
        pairs = safe_zip(dtypes, weak_types)
        return [core.UnshapedArray(dtype, weak_type=wt)
                for dtype, wt in pairs]

    raise TypeError(avals, kind)
示例#3
0
文件: prng.py 项目: Jakob-Unfried/jax
def _threefry2x32_abstract_eval(*args):
  """Abstract evaluation rule for the threefry2x32 primitive.

  Every argument must have dtype uint32. The two outputs share one aval.
  When every argument is a ``core.ShapedArray`` (or a subclass), that aval is
  a uint32 ``ShapedArray``:

  * its shape comes from ``lax._broadcasting_shape_rule`` applied to the
    arguments;
  * its named shape joins the arguments' named shapes.

  Otherwise the shared aval is a uint32 ``UnshapedArray``.

  Args:
    *args: abstract values of the primitive's operands.

  Returns:
    A 2-tuple holding the same output aval twice.

  Raises:
    TypeError: if any argument's dtype is not uint32.
  """
  has_wrong_dtype = any(a.dtype != jnp.uint32 for a in args)
  if has_wrong_dtype:
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))

  # Without full shape information, only the dtype of the outputs is known.
  if not all(isinstance(arg, core.ShapedArray) for arg in args):
    out_aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
    return out_aval, out_aval

  out_shape = lax._broadcasting_shape_rule(*args)
  out_named_shape = core.join_named_shapes(*(a.named_shape for a in args))
  out_aval = core.ShapedArray(out_shape, jnp.dtype(jnp.uint32),
                              named_shape=out_named_shape)
  return out_aval, out_aval