Beispiel #1
0
def _eval_on_shapes(f, *args):
    """Evaluates `f` abstractly, using only the shapes and dtypes of `args`.

    Zero-filled arrays are created from the `shape` and `dtype` attributes
    found in `args` (via `nested_map`), converted to JaxTuple trees,
    abstractified, and fed through `pe.abstract_eval_fun`.

    Args:
        f: the function to evaluate abstractly.
        *args: objects exposing `shape` and `dtype`; presumably these may be
            nested containers, since they are traversed with `nested_map`.

    Returns:
        A `ShapeType` for an array result, or (nested) tuples of `ShapeType`
        when the abstract result is an `AbstractTuple`.
    """
    def shape_type_tree(aval):
        # Abstract tuples are unpacked recursively; leaves become ShapeType.
        if not isinstance(aval, jax.core.AbstractTuple):
            return ShapeType(aval.shape, aval.dtype)
        return tuple(shape_type_tree(child) for child in aval)

    def jaxtuple_to_pytree(node):
        # Mirror image of the conversion applied to the inputs below.
        if not isinstance(node, jax.core.JaxTuple):
            return node
        return tuple(jaxtuple_to_pytree(child) for child in node)

    def zeros_from_spec(spec):
        return backend.numpy.zeros(shape=spec.shape, dtype=spec.dtype)

    def to_shaped_aval(value):
        return jax.abstract_arrays.raise_to_shaped(jax.core.get_aval(value))

    def f_on_jaxtuples(*jaxtuple_inputs):
        # `f` works on pytrees, so unwrap the inputs and re-wrap the output.
        pytree_inputs = map(jaxtuple_to_pytree, jaxtuple_inputs)
        result = f(*pytree_inputs)
        result_jaxtuple, _ = jax.api_util.pytree_to_jaxtupletree(result)
        return result_jaxtuple

    zero_args = nested_map(args, zeros_from_spec)
    converted = map(jax.api_util.pytree_to_jaxtupletree, zero_args)
    jaxtuple_args, _ = jax.util.unzip2(converted)
    out_aval = pe.abstract_eval_fun(
        f_on_jaxtuples, *map(to_shaped_aval, jaxtuple_args))

    return shape_type_tree(out_aval)
Beispiel #2
0
        def layer_abstract_eval(*avals):
            """Abstractly evaluates `init_fun` followed by `apply_fun`.

            A `ShapedArray` of shape `(2,)` and dtype `'uint32'` stands in for
            the key that `init_fun` receives as its first argument; `avals`
            are forwarded as the remaining inputs to both functions.
            """
            def init_then_apply(rng_key, *xs):
                # Parameters produced by init are consumed directly by apply.
                return apply_fun(init_fun(rng_key, *xs), *xs)

            key_aval = ShapedArray((2, ), 'uint32')
            return pe.abstract_eval_fun(init_then_apply, key_aval, *avals)
Beispiel #3
0
def _grid_trace_shape(fn, *args, **kwargs):
    """Returns the shape of `fn`'s output using abstract evaluation.

    Each positional argument that is an `np.ndarray` is replaced by a
    `ShapedArray` carrying the same shape and dtype; every other positional
    argument, and all keyword arguments, are forwarded unchanged.
    """
    abstract_args = [
        ShapedArray(tuple(arg.shape), arg.dtype)
        if isinstance(arg, np.ndarray) else arg
        for arg in args
    ]
    out_aval = pe.abstract_eval_fun(fn, *abstract_args, **kwargs)
    return out_aval.shape
Beispiel #4
0
def _canonicalize_displacement_or_metric(displacement_or_metric):
    """Checks whether a displacement or a metric was provided.

    The function is probed by abstractly evaluating it (with `t=0`) on two
    float32 inputs of shape `(1, dim)` for `dim` in `range(4)`. The first
    `dim` for which this does not raise `ValueError` decides the outcome.

    Args:
        displacement_or_metric: callable accepting two arrays and a `t`
            keyword argument.

    Returns:
        `displacement_or_metric` unchanged when its abstract output has
        rank 2, otherwise `space.metric(displacement_or_metric)`.

    Raises:
        ValueError: if every probed dimension raised `ValueError`.
    """
    for dim in range(4):
        try:
            R = ShapedArray((1, dim), f32)
            dR_or_dr = pe.abstract_eval_fun(displacement_or_metric, R, R, t=0)
            if len(dR_or_dr.shape) == 2:
                return displacement_or_metric
            else:
                return space.metric(displacement_or_metric)
        except ValueError:
            # This dimension is not accepted by the function; try the next.
            continue
    # NOTE(review): the loop probes dim 0..3 while the message says
    # "larger than 4" -- confirm the intended bound.
    # The trailing space keeps the two literals from fusing into "largerthan".
    raise ValueError(
        'Canonicalize displacement not implemented for spatial dimension '
        'larger than 4.')
Beispiel #5
0
 def _abstract(*flat_avals, **params):
   """Abstractly evaluates `self.impl` on `flat_avals` with `params`."""
   abstract_eval = pe.abstract_eval_fun
   return abstract_eval(self.impl, *flat_avals, **params)
Beispiel #6
0
def _infer_shape_jax(f, *vals, **params):
    """Abstractly evaluates `f` and returns the result for its output leaves.

    `vals` are mapped through `abstractify` before evaluation. The output of
    `f` is flattened with `tree_util.tree_leaves`, so the result corresponds
    to the flat list of leaves rather than the original output structure.
    """
    def flat_f(*a, **k):
        return tree_util.tree_leaves(f(*a, **k))

    avals = map(abstractify, vals)
    return pe.abstract_eval_fun(flat_f, *avals, **params)