def _eval_on_shapes(f, *args):
  """Evaluates f given only shapes and types.

  Every leaf of `args` (mapped via `nested_map`) is materialized as a zero
  array with the leaf's `.shape` and `.dtype`. The resulting pytrees are
  converted to JaxTuple trees and abstractified. `f` is then traced through
  `pe.abstract_eval_fun`, so only shape/dtype information reaches it.

  Args:
    f: function to evaluate; it receives pytrees rebuilt from JaxTuples.
    *args: pytrees whose leaves expose `.shape` and `.dtype`.

  Returns:
    `ShapeType(shape, dtype)` for a single abstract output, or a (nested)
    tuple of them mirroring any `AbstractTuple` structure in the output.
  """

  def to_shaped_aval(value):
    # Forget the concrete value; keep only the shaped abstract value.
    return jax.abstract_arrays.raise_to_shaped(jax.core.get_aval(value))

  def zeros_like_spec(spec):
    return backend.numpy.zeros(shape=spec.shape, dtype=spec.dtype)

  def as_pytree(node):
    # Recursively turn JaxTuples back into plain Python tuples.
    if not isinstance(node, jax.core.JaxTuple):
      return node
    return tuple(as_pytree(child) for child in node)

  def summarize(aval):
    # Recursively describe an abstract output as ShapeType leaves.
    if not isinstance(aval, jax.core.AbstractTuple):
      return ShapeType(aval.shape, aval.dtype)
    return tuple(summarize(child) for child in aval)

  def flat_fn(*jaxtuple_inputs):
    # Adapter: f works on pytrees, while abstract_eval_fun works on JaxTuples.
    out = f(*[as_pytree(x) for x in jaxtuple_inputs])
    out_tree, _ = jax.api_util.pytree_to_jaxtupletree(out)
    return out_tree

  zero_args = nested_map(args, zeros_like_spec)
  converted = [jax.api_util.pytree_to_jaxtupletree(a) for a in zero_args]
  jaxtuple_args, _ = jax.util.unzip2(converted)
  abstract_inputs = [to_shaped_aval(x) for x in jaxtuple_args]
  abstract_out = pe.abstract_eval_fun(flat_fn, *abstract_inputs)
  return summarize(abstract_out)
def layer_abstract_eval(*avals):
  """Abstractly evaluates `init_fun` followed by `apply_fun` on `avals`.

  `init_fun` and `apply_fun` are free variables taken from the enclosing
  scope. The first traced argument is a shaped uint32 array of shape (2,),
  presumably standing in for a PRNG key (TODO confirm against init_fun).

  Returns:
    The result of `pe.abstract_eval_fun` for
    `apply_fun(init_fun(key, *inputs), *inputs)`.
  """
  def _init_then_apply(rng, *xs):
    return apply_fun(init_fun(rng, *xs), *xs)

  key_aval = ShapedArray((2, ), 'uint32')
  return pe.abstract_eval_fun(_init_then_apply, key_aval, *avals)
def _grid_trace_shape(fn, *args, **kwargs):
  """Traces a function to compute the shape of its output.

  Positional arguments that are `np.ndarray` instances are replaced by
  `ShapedArray` avals with the same shape and dtype. All other positional
  arguments, and every keyword argument, are forwarded unchanged to
  `pe.abstract_eval_fun`.

  Returns:
    The `.shape` of the abstract output of `fn`.
  """
  def _to_aval(value):
    if not isinstance(value, np.ndarray):
      return value
    return ShapedArray(tuple(value.shape), value.dtype)

  abstract_args = [_to_aval(a) for a in args]
  return pe.abstract_eval_fun(fn, *abstract_args, **kwargs).shape
def _canonicalize_displacement_or_metric(displacement_or_metric):
  """Checks whether or not a displacement or metric was provided.

  The callable is probed with abstract float32 positions of shape (1, dim),
  for dim in 0, 1, 2, 3 in order. Probing stops at the first dimension for
  which abstract evaluation does not raise ValueError.

  Args:
    displacement_or_metric: callable invoked as `fn(Ra, Rb, t=0)`.

  Returns:
    `displacement_or_metric` unchanged when its abstract output has rank 2.
    Otherwise, `space.metric(displacement_or_metric)`.

  Raises:
    ValueError: if abstract evaluation raises ValueError for every probed
      dimension.
  """
  # NOTE(review): only dims 0..3 are probed, although the message below
  # mentions 4 — confirm the intended upper bound.
  for dim in range(4):
    R = ShapedArray((1, dim), f32)
    try:
      dR_or_dr = pe.abstract_eval_fun(displacement_or_metric, R, R, t=0)
    except ValueError:
      # Presumably the callable only supports certain spatial dimensions;
      # try the next one. Only the probe is guarded, so errors raised while
      # wrapping are not silently swallowed.
      continue
    if len(dR_or_dr.shape) == 2:
      return displacement_or_metric
    return space.metric(displacement_or_metric)
  raise ValueError(
      'Canonicalize displacement not implemented for spatial dimension larger '
      'than 4.')
def _abstract(*flat_avals, **params):
  """Abstract evaluation rule: traces `self.impl` over `flat_avals`.

  `self` is a free variable captured from the enclosing scope. `params` are
  forwarded as keyword arguments through `pe.abstract_eval_fun`.
  """
  impl = self.impl
  return pe.abstract_eval_fun(impl, *flat_avals, **params)
def _infer_shape_jax(f, *vals, **params):
  """Abstractly evaluates `f` on `vals` with its output flattened to leaves.

  Each value is passed through `abstractify` before tracing. `f`'s output
  pytree is flattened with `tree_util.tree_leaves`. `params` are forwarded as
  keyword arguments to `f` via `pe.abstract_eval_fun`.

  Returns:
    Whatever `pe.abstract_eval_fun` produces for the flattened outputs.
  """
  def _flat_f(*args, **kwargs):
    return tree_util.tree_leaves(f(*args, **kwargs))

  abstract_vals = [abstractify(v) for v in vals]
  return pe.abstract_eval_fun(_flat_f, *abstract_vals, **params)