def test_var_ordering(self):
  """Vars sort back into creation order no matter how they are permuted."""
  gen = core.gensym()
  expected = [gen(core.ShapedArray((), np.dtype('int32'))) for _ in range(3)]
  for perm in it.permutations(expected):
    assert sorted(perm) == expected
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
                                             in_tree, in_avals,
                                             primitive_name: str):
  """Stage each of `funs` and give every staged jaxpr the same const signature.

  Staging a branch extracts its constants and turns them into jaxpr arguments.
  A conditional *primitive* needs all of its branches to share one input
  signature. So each staged jaxpr is padded to accept the constants of *all*
  branches, in branch order, and it reads only the ones it actually owns.

  Args:
    funs: the callables to stage, one per branch.
    in_tree: input pytree structure forwarded to `_initial_style_open_jaxpr`.
    in_avals: input avals forwarded to `_initial_style_open_jaxpr`.
    primitive_name: name forwarded to `_initial_style_open_jaxpr`.

  Returns:
    A triple of
    - one closed jaxpr per function, with constvars converted to leading
      invars;
    - the concatenation of every branch's constants;
    - the per-function output trees.
  """
  jaxprs, all_consts, all_out_trees = unzip3(
      [_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
       for fun in funs])

  newvar = core.gensym(jaxprs, suffix='_')
  # For every branch, make one fresh placeholder var per constant. A branch
  # uses these in its signature for the constants owned by *other* branches.
  const_avals_per_branch = [[_abstractify(c) for c in consts]
                            for consts in all_consts]
  placeholders = [[newvar(aval) for aval in avals]
                  for avals in const_avals_per_branch]

  def widen(idx, open_jaxpr):
    # Put placeholders for earlier branches first, then this branch's real
    # constvars, then placeholders for later branches.
    leading = util.concatenate(placeholders[:idx])
    trailing = util.concatenate(placeholders[idx + 1:])
    return open_jaxpr.replace(
        constvars=[*leading, *open_jaxpr.constvars, *trailing])

  widened = [widen(idx, open_jaxpr) for idx, open_jaxpr in enumerate(jaxprs)]
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(j), ())
                   for j in widened]
  return closed_jaxprs, util.concatenate(all_consts), all_out_trees
def test_var_ordering(self):
  """Vars sort back into creation order no matter how they are permuted."""
  gen = core.gensym()
  expected = [gen(core.abstract_unit) for _ in range(3)]
  for perm in it.permutations(expected):
    assert sorted(perm) == expected
def test_var_tree_flatten(self):
  """Flattening a Var-keyed dict yields its values ordered by key."""
  gen = core.gensym()
  a, b, c, d = [gen(core.abstract_unit) for _ in range(4)]
  mapping = {c: d, a: b}
  leaves_text = ''.join(str(leaf) for leaf in tree_leaves(mapping))
  assert leaves_text == 'bd'
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Rewrite a Jaxpr to thread the token, if needed.

  Equations whose primitive uses outfeed (per `xla.primitive_uses_outfeed`)
  are rewritten by `_rewrite_eqn`. Each one consumes the current token var
  and gets a fresh token var as its output token. All other equations are
  copied unchanged.

  Args:
    jaxpr: the jaxpr to rewrite.
    has_input_token: if True, a token var is appended to the invars. If
      False, a token is created in the body with `lax.create_token_p`.
    has_output_token: if True, the final token var is appended to the outvars.
      This requires `has_input_token` (see the assert).

  Returns:
    `jaxpr` itself, unchanged, when there is no input token and the jaxpr does
    not use outfeed. Otherwise, a new `core.Jaxpr` with the token threaded
    through.
  """
  assert has_input_token or not has_output_token
  if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
    return jaxpr

  mk_new_var = core.gensym([jaxpr])

  eqns: List[core.JaxprEqn] = []
  # Var holding the current token: it is either the incoming token invar or
  # the result of the create_token equation added below.
  last_token_var = mk_new_var(core.abstract_token)
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    # Materialize a token inside the body; jaxpr.invars[0] is passed as the
    # equation's operand.
    # NOTE(review): this raises IndexError if `jaxpr` has no invars. Confirm
    # that callers never reach this branch with a nullary jaxpr.
    eqns.append(
        core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
                           lax.create_token_p, {}, source_info_util.current()))

  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
    else:
      output_token_var = mk_new_var(core.abstract_token)
      # Presumably `_rewrite_eqn` appends the rewritten equation(s) to `eqns`,
      # reading last_token_var and defining output_token_var. Verify there.
      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, mk_new_var)
      last_token_var = output_token_var

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
  return new_jaxpr
def test_jaxpr_undefined_eqn_invar(self):
  """check_jaxpr rejects an equation that reads a never-defined variable."""
  jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
  cos_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cos')
  # Swap cos's input for a brand-new var that nothing in the jaxpr defines.
  dangling = core.gensym([jaxpr], suffix='_test')(cos_eqn.invars[0].aval)
  cos_eqn.invars[0] = dangling
  self.assertRaisesRegex(
      core.JaxprTypeError,
      r"Variable '.+_test' not defined\n\nin equation:",
      lambda: core.check_jaxpr(jaxpr))
def test_comparing_var(self):
  """Vars from one gensym compare in creation order and are pairwise distinct."""
  gen = core.gensym('')
  a, b, c = [gen() for _ in range(3)]
  assert a < b < c
  assert c > b > a
  assert a != b
  assert b != c
  assert a != c
def test_comparing_var(self):
  """Vars from one gensym compare in creation order and are pairwise distinct."""
  gen = core.gensym()
  a, b, c = [gen(core.abstract_unit) for _ in range(3)]
  assert a < b < c
  assert c > b > a
  assert a != b
  assert b != c
  assert a != c
def test_comparing_var(self):
  """Vars from one gensym compare in creation order and are pairwise distinct."""
  gen = core.gensym()
  a, b, c = [gen(core.ShapedArray((), np.dtype('int32'))) for _ in range(3)]
  assert a < b < c
  assert c > b > a
  assert a != b
  assert b != c
  assert a != c
def ignore_errors_jaxpr(jaxpr, error):
  """Constructs a jaxpr which takes two extra args but ignores them.

  The two new leading invars get the shaped avals of `error.err` and
  `error.code`, in that order. Nothing in the body reads them. Constvars,
  outvars and equations are reused from `jaxpr` as-is.

  Args:
    jaxpr: a closed jaxpr (its `.jaxpr` and `.consts` are read).
    error: an object whose `.err` and `.code` values supply the avals of the
      two extra args.

  Returns:
    A `core.ClosedJaxpr` with the original consts and the widened signature.
  """
  closed_consts = jaxpr.consts
  inner = jaxpr.jaxpr
  fresh = core.gensym([inner])
  err_var = fresh(core.raise_to_shaped(core.get_aval(error.err)))
  code_var = fresh(core.raise_to_shaped(core.get_aval(error.code)))
  # NOTE(review): only constvars/invars/outvars/eqns are passed to core.Jaxpr.
  # If this version of Jaxpr also carries effects or debug info, they are not
  # forwarded here. Confirm that this is intended.
  widened = core.Jaxpr(inner.constvars, (err_var, code_var, *inner.invars),
                       inner.outvars, inner.eqns)
  return core.ClosedJaxpr(widened, closed_consts)
def _nonzero_typecheck_rule(invar):
  """Compute the output avals for `nonzero` applied to `invar`.

  The last axis of `invar` must be either an int or an object with a
  `_bound` attribute. That value bounds the per-row count var.

  Returns:
    A pair (count var, aval of the index array). The count var has shape
    `invar.aval.shape[:-1]` and a `BoundedIntTy(bound)` element type. The
    index array is int32 with its last dimension given by the count var: the
    var itself for 1-D input, or a `DimIndexingExpr` over it otherwise.
  """
  shape = invar.aval.shape
  last_dim = shape[-1]
  if isinstance(last_dim, int):
    bound = last_dim
  else:
    bound = last_dim._bound
  out_dim_var = core.gensym()(AbsArray(shape[:-1], BoundedIntTy(bound)))

  if len(shape) == 1:
    return out_dim_var, AbsArray((out_dim_var,), BaseType(np.dtype('int32')))

  indices = tuple(range(len(out_dim_var.aval.shape)))  # pytype: disable=attribute-error
  last_expr = DimIndexingExpr(out_dim_var, indices)
  return out_dim_var, AbsArray((*shape[:-1], last_expr),
                               BaseType(np.dtype('int32')))
def refresh_names(jaxpr):
  """Return a copy of `jaxpr` in which every Var is replaced by a fresh one.

  A single `jc.gensym()` supplies the new vars. Each distinct Var maps to
  exactly one new Var with the same aval, and the mapping is applied
  consistently across constvars, invars, equation inputs/outputs and outvars.
  Literals are kept unchanged wherever they may appear: as equation inputs and
  as jaxpr outputs (e.g. a function that returns a constant).

  NOTE(review): equation outvars are all renamed, drop-vars included. If drop
  vars are shared singletons in this JAX version, they would all collapse onto
  one new Var. Confirm if that matters here.
  """
  renamed = {}
  fresh = jc.gensym()

  def rename(v):
    # Literals are not variables: leave them untouched.
    if isinstance(v, Literal):
      return v
    if v not in renamed:
      renamed[v] = fresh(v.aval)
    return renamed[v]

  jaxpr_constvars = map(rename, jaxpr.constvars)
  jaxpr_invars = map(rename, jaxpr.invars)
  new_eqns = []
  for eqn in jaxpr.eqns:
    invars = map(rename, eqn.invars)
    outvars = map(rename, eqn.outvars)
    new_eqns.append(
        JaxprEqn(invars, outvars, eqn.primitive, eqn.params, eqn.source_info))
  # Outvars may be Literals; the previous version renamed them as if they were
  # Vars, which either failed on hashing or produced an undefined output var.
  jaxpr_outvars = map(rename, jaxpr.outvars)
  return Jaxpr(jaxpr_constvars, jaxpr_invars, jaxpr_outvars, new_eqns)
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  """Give every staged branch jaxpr one shared residual-input signature.

  One fresh var is created for each aval in `all_res_avals`. For each closed
  jaxpr, its leading `len(res_indices)` invars are its own residuals. They are
  placed into the shared residual list at positions `res_indices` via
  `util.subvals`, and every other slot gets the unused shared var. The
  jaxpr's remaining (non-residual) invars follow.

  Returns:
    A tuple of `core.ClosedJaxpr`s, one per input jaxpr, with the same consts.
  """
  newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
  shared_res_vars = [newvar(aval) for aval in all_res_avals]

  def augment_jaxpr(closed, res_indices):
    inner = closed.jaxpr
    n_res = len(res_indices)
    own_res, rest = inner.invars[:n_res], inner.invars[n_res:]
    joined_res = list(util.subvals(shared_res_vars, zip(res_indices, own_res)))
    widened = core.Jaxpr(inner.constvars, joined_res + rest, inner.outvars,
                         inner.eqns, inner.effects)
    return core.ClosedJaxpr(widened, closed.consts)

  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
def _append_invars(jaxpr, avals):
  """Return a copy of `jaxpr` with one fresh, unused invar appended per aval."""
  fresh = core.gensym([jaxpr])
  extra_invars = [fresh(aval) for aval in avals]
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars + extra_invars,
                    jaxpr.outvars, jaxpr.eqns)
def test_var_compared_by_identity(self):
  """Vars from independent gensyms print alike but are not equal."""
  first, second = [core.gensym()(core.abstract_unit) for _ in range(2)]
  assert str(first) == str(second)
  assert first != second
def test_var_compared_by_identity(self):
  """Vars from independent gensyms print alike but are not equal."""
  first, second = [core.gensym()(core.ShapedArray((), np.dtype('int32')))
                   for _ in range(2)]
  assert str(first) == str(second)
  assert first != second
def test_var_tree_flatten(self):
  """Flattening a Var-keyed dict yields its values ordered by key."""
  gen = core.gensym('')
  a, b, c, d = [gen() for _ in range(4)]
  mapping = {c: d, a: b}
  leaves_text = ''.join(str(leaf) for leaf in tree_leaves(mapping))
  assert leaves_text == 'bd'
def test_var_compared_by_identity(self):
  """Vars from independent gensyms print alike but are not equal."""
  first, second = [core.gensym('')() for _ in range(2)]
  assert str(first) == str(second)
  assert first != second
def test_var_tree_flatten(self):
  """Flattening a Var-keyed dict yields its values ordered by key."""
  gen = core.gensym()
  int32_scalar = core.ShapedArray((), np.dtype('int32'))
  a, b, c, d = [gen(int32_scalar) for _ in range(4)]
  mapping = {c: d, a: b}
  leaves_text = ''.join(str(leaf) for leaf in tree_leaves(mapping))
  assert leaves_text == 'bd'
def _tracers_to_jaxpr(in_tracers, out_tracers):
  """Constructs Jaxpr given tracers for inputs and outputs.

  Copied from jax.interpreters.partial_eval.tracers_to_jaxpr but modified to
  raise a VariableError when unknown in_tracers are found, rather than the
  default AssertionError.

  Args:
    in_tracers: the tracers that were created for the function inputs
    out_tracers: the tracers that were output by the function.

  Returns:
    a triple of a `Jaxpr`, the constant values corresponding to the
    `constvars` in the returned Jaxpr, and the environment values. The vars for
    the environment values have been prepended to the Jaxpr's `invars`.

  Raises:
    VariableError: if an unknown input tracer is found
  """
  # One var generator shared by tracer vars and constant vars.
  newvar = jax_core.gensym(None)
  # Maps id(tracer) -> Var (or Literal / unitvar) so each tracer gets exactly
  # one binding.
  t_to_var = {}

  def getvar(t):
    var = t_to_var.get(id(t))
    if var is None:
      var = newvar(t.pval.get_aval())
      t_to_var[id(t)] = var
    return var

  sorted_tracers = jax_util.toposort(out_tracers)
  invars = safe_map(getvar, in_tracers)
  eqns = []
  env = {}
  consts = {}
  # Memoizes one Var per constant object, keyed by id(constant).
  const_to_var = {}

  def getconstvar(c):
    var = const_to_var.get(id(c))
    if var is None:
      var = newvar(jax_core.get_aval(c))
      const_to_var[id(c)] = var
    return var

  # Several tracers can carry the same equation recipe (same eqn_id). Emit
  # each equation only once.
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, pe.JaxprEqnRecipe):
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(pe.recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, pe.LambdaBinding):
      # A lambda-bound tracer must be one of the declared inputs (identity
      # check). Anything else is the case this copy reports as VariableError.
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise VariableError('Found unknown input tracer.')
      assert in_tracers, 'Lambda binding with no args'
    elif isinstance(recipe, pe.FreeVar):
      env[getvar(t)] = recipe.val
    elif isinstance(recipe, pe.ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, jax_core.Literal):
      t_to_var[id(t)] = recipe
    elif recipe is jax_core.unit:
      t_to_var[id(t)] = jax_core.unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = jax_util.unzip2(env.items())
  const_vars, const_vals = jax_util.unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = jax_core.Jaxpr(const_vars, list(it.chain(env_vars, invars)),
                         safe_map(getvar, out_tracers), eqns)
  return jaxpr, const_vals, env_vals