def inline_calls(jaxpr):
  """Returns a copy of `jaxpr` with `call_p` / `xla_call_p` equations inlined.

  Every equation that carries a call jaxpr is replaced, recursively, by the
  equations of that call jaxpr. The callee's invars and outvars are renamed to
  the caller's variables. Equations without a call jaxpr are kept unchanged.

  Args:
    jaxpr: the `Jaxpr` to flatten.

  Returns:
    A new `Jaxpr` with the same constvars, invars and outvars as `jaxpr` and a
    flat equation list.

  Raises:
    NotImplementedError: if an equation carries a call jaxpr but its primitive
      is neither `jc.call_p` nor `xla.xla_call_p`.

  NOTE(review): outer outvars are only defined by renaming the outvars of
  inlined equations. If a called jaxpr returns one of its own inputs (or a
  literal) directly, no inlined equation writes the corresponding outer
  outvar, and the result refers to an undefined variable.
  """
  inlinable_primitives = {jc.call_p, xla.xla_call_p}
  new_eqns = []

  def check_inlinable(primitive):
    if primitive not in inlinable_primitives:
      raise NotImplementedError(
          f'Cannot inline calls to {primitive}; only call_p and xla_call_p '
          'are supported.')

  def inline_call(jaxpr, invars, outvars):
    # Callee invars -> caller values; callee outvars -> caller outvars.
    inmap = dict(zip(jaxpr.invars, invars))
    outmap = dict(zip(jaxpr.outvars, outvars))

    def rename_invar(v):
      if isinstance(v, Literal):
        return v  # Literals are passed through and never looked up.
      if v in inmap:
        return inmap[v]
      # The producer of a callee outvar writes the caller's variable (see
      # `new_outvars` below), so later reads of it must be renamed as well.
      return outmap.get(v, v)

    for eqn in jaxpr.eqns:
      new_invars = [rename_invar(v) for v in eqn.invars]
      new_outvars = [outmap.get(v, v) for v in eqn.outvars]
      call_jaxpr, _ = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
      if call_jaxpr:
        check_inlinable(eqn.primitive)
        inline_call(call_jaxpr, new_invars, new_outvars)
      else:
        new_eqns.append(
            JaxprEqn(new_invars, new_outvars, eqn.primitive, eqn.params,
                     eqn.source_info))

  for eqn in jaxpr.eqns:
    call_jaxpr, _ = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
      check_inlinable(eqn.primitive)
      inline_call(call_jaxpr, eqn.invars, eqn.outvars)
    else:
      # Top-level equations need no renaming and are reused as-is.
      new_eqns.append(eqn)
  return Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
def lazy_eval_jaxpr(jaxpr, consts, *args):
  """Binds `jaxpr`'s outputs to `LazyArray` placeholders instead of computing them.

  The equations are walked twice. The first pass gives every equation outvar a
  fresh `LazyArray`. The second pass hands each placeholder a `JaxprEqn` whose
  invars/outvars are the already-bound values: literal values, constants,
  arguments, or other placeholders.

  Args:
    jaxpr: the jaxpr to evaluate lazily; it must not contain call primitives.
    consts: values bound to `jaxpr.constvars`.
    *args: values bound to `jaxpr.invars`.

  Returns:
    The values bound to `jaxpr.outvars`.

  Raises:
    NotImplementedError: if any equation carries a call jaxpr.
  """
  env = {}

  def read(v):
    is_literal = type(v) in {jc.Literal, Literal_}
    return v.val if is_literal else env[v]

  def write(v, val):
    env[v] = val

  write(jc.unitvar, jc.unit)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)

  # Pass 1: one fresh LazyArray placeholder per equation output.
  for eqn in jaxpr.eqns:
    sub_jaxpr, _ = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
    if sub_jaxpr:
      raise NotImplementedError
    placeholders = map(LazyArray, eqn.outvars)
    map(write, eqn.outvars, placeholders)

  # Pass 2: attach to each placeholder the equation that produces it.
  for eqn in jaxpr.eqns:
    in_values = map(read, eqn.invars)
    out_arrays = map(read, eqn.outvars)
    bound_eqn = jc.JaxprEqn(in_values, out_arrays, eqn.primitive, eqn.params,
                            eqn.source_info)
    map(lambda arr: arr.set_eqn(bound_eqn), out_arrays)

  return map(read, jaxpr.outvars)
def jaxpr_to_expressions(jaxpr: jax_core.Jaxpr) -> Tuple[Expr, ...]:
  """Converts a JAXpr into a tuple of output `JaxExpression`s.

  Constvars and invars become `JaxVar` patterns named after the variable.
  Each equation becomes either a `CallPrimitive` (for equations carrying a
  call jaxpr, converted recursively) or the expression built by
  `primitive_to_expression`. Multiple-result equations are split into `Part`s.

  Args:
    jaxpr: a `jax.core.Jaxpr` to be converted into a tuple of
      `JaxExpression`s.

  Returns:
    A tuple of `JaxExpression`s, one per entry of `jaxpr.outvars`.
  """
  # Keyed by the variable's string name rather than the Var object.
  env = {}

  def read_env(var: jax_core.Var) -> Any:
    if isinstance(var, jax_core.Literal):
      return Literal(var.val)
    return env[str(var)]

  def write_env(var: jax_core.Var, val: Any) -> None:
    if isinstance(var, jax_core.Literal):
      return
    env[str(var)] = val

  def to_pattern(var: jax_core.Var) -> Any:
    return JaxVar(str(var), var.aval.shape, var.aval.dtype)

  const_patterns = jax_util.safe_map(to_pattern, jaxpr.constvars)
  jax_util.safe_map(write_env, jaxpr.constvars, const_patterns)
  in_patterns = jax_util.safe_map(to_pattern, jaxpr.invars)
  jax_util.safe_map(write_env, jaxpr.invars, in_patterns)

  for eqn in jaxpr.eqns:
    operands = tuple(jax_util.safe_map(read_env, eqn.invars))
    call_jaxpr, params = jax_core.extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
      call_expression = BoundExpression(jaxpr_to_expressions(call_jaxpr), {})
      variable_names = tuple(map(str, call_jaxpr.invars))
      out = CallPrimitive(eqn.primitive, operands, call_expression,
                          Params(params), variable_names)
    else:
      out = primitive_to_expression(eqn.primitive)(operands, Params(params))
    if eqn.primitive.multiple_results:
      out_parts = [Part(out, i) for i in range(len(eqn.outvars))]
      jax_util.safe_map(write_env, eqn.outvars, out_parts)
    else:
      write_env(eqn.outvars[0], out)
  return tuple(jax_util.safe_map(read_env, jaxpr.outvars))
def eval_jaxpr_with_state(jaxpr: jax_core.Jaxpr, rules: Rules,
                          consts: Sequence[Value], state: Value,
                          *args: Value) -> Tuple[List[Value], Value]:
  """Evaluates a JAXpr while threading a state value through primitive rules.

  This mirrors `jax.core.eval_jaxpr`, with two differences:

  1. Before falling back to `primitive.bind`, each primitive is looked up in
     `rules`. Equations carrying a call jaxpr are dispatched to
     `_effect_handler_call_rules` (or `default_call_interpreter_rule`).
  2. A `state` value is passed into every rule. The updated state each rule
     returns is carried forward, and the final value is returned alongside
     the outputs.

  Args:
    jaxpr: The JAXpr to be interpreted.
    rules: A `dict` mapping JAX primitives to functions that take
      `(state, *args)` and return `(output, new_state)`.
    consts: Values for the JAXpr's constvars.
    state: The initial interpreter state.
    *args: Values for the JAXpr's invars.

  Returns:
    A pair `(outputs, final_state)`, where `outputs` lists the values of
    `jaxpr.outvars`.
  """
  env = Environment()
  jax_util.safe_map(env.write, jaxpr.constvars, consts)
  jax_util.safe_map(env.write, jaxpr.invars, args)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    in_values = jax_util.safe_map(env.read, eqn.invars)
    sub_jaxpr, prim_params = jax_core.extract_call_jaxpr(prim, eqn.params)
    if sub_jaxpr:
      fallback = functools.partial(default_call_interpreter_rule, prim)
      handler = _effect_handler_call_rules.get(prim, fallback)
      out_values, state = handler(rules, state, in_values, sub_jaxpr,
                                  **prim_params)
    elif prim in rules:
      out_values, state = rules[prim](state, *in_values, **prim_params)
    else:
      # No rule registered: evaluate normally; the state is left unchanged.
      out_values = prim.bind(*in_values, **prim_params)

    if not prim.multiple_results:
      env.write(eqn.outvars[0], out_values)
    else:
      jax_util.safe_map(env.write, eqn.outvars, out_values)

  final_outputs = jax_util.safe_map(env.read, jaxpr.outvars)
  return final_outputs, state
def eval_jaxpr_with_kwargs(jaxpr: jax_core.Jaxpr, consts: Iterable[Any], *args,
                           **kwargs):
  """Evals a jaxpr while passing kwargs into registered primitives.

  Primitives found in `kwargs_rules` are called through their rule. They
  receive a `kwargs=` argument that merges the equation's own `'kwargs'`
  param with the caller's `**kwargs`, where the caller's entries win. Every
  other primitive is evaluated with `bind`. Call jaxprs are evaluated
  recursively with the same `**kwargs`.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: values for `jaxpr.constvars`.
    *args: values for `jaxpr.invars`.
    **kwargs: keyword arguments forwarded to registered primitives.

  Returns:
    A list with the values of `jaxpr.outvars`.
  """
  # Local import: `functools.partial` replaces `jax_core.partial`, which only
  # worked because `jax.core` happened to re-export it.
  import functools

  def read(v):
    if isinstance(v, jax_core.Literal):
      return v.val
    else:
      return env[v]

  def write(v, val):
    env[v] = val

  env = {}
  write(jax_core.unitvar, jax_core.unit)
  safe_map(write, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_vals = safe_map(read, eqn.invars)
    subjaxpr, params = jax_core.extract_call_jaxpr(eqn.primitive, eqn.params)
    if subjaxpr:
      subfuns = [
          lu.wrap_init(
              functools.partial(eval_jaxpr_with_kwargs, subjaxpr, (), **kwargs))
      ]
    else:
      subfuns = []
    # Copy the params returned by `extract_call_jaxpr`. For call primitives
    # they already exclude 'call_jaxpr'. The copy also ensures the `pop`
    # below never mutates `eqn.params`.
    params = dict(params)
    if eqn.primitive in kwargs_rules:
      new_kwargs = dict(params.pop('kwargs', {}), **kwargs)
      ans = kwargs_rules[eqn.primitive](*(subfuns + in_vals),
                                        kwargs=new_kwargs, **params)
    else:
      ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
    if eqn.primitive.multiple_results:
      safe_map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
  return safe_map(read, jaxpr.outvars)
def propagate(cell_type: Type[Cell],
              rules: Dict[jax_core.Primitive, PropagationRule],
              jaxpr: pe.Jaxpr,
              constcells: List[Cell],
              incells: List[Cell],
              outcells: List[Cell]) -> Environment:
  """Runs rule-driven cell propagation over `jaxpr` until its work queue drains.

  Args:
    cell_type: the `Cell` class; `cell_type.unknown(aval)` seeds every
      equation variable.
    rules: maps each JAX primitive to its propagation rule function.
    jaxpr: the Jaxpr whose equations form the propagation graph.
    constcells: cells written to `jaxpr.constvars`.
    incells: cells written to `jaxpr.invars`.
    outcells: cells written to `jaxpr.outvars`.

  Returns:
    The `Environment` holding each variable's cell after propagation ends.
  """
  env = Environment(cell_type, jaxpr)
  for boundary_vars, boundary_cells in ((jaxpr.constvars, constcells),
                                        (jaxpr.invars, incells),
                                        (jaxpr.outvars, outcells)):
    safe_map(env.write, boundary_vars, boundary_cells)

  equations = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)
  neighbors_of = construct_graph_representation(equations)

  # Every variable an equation touches starts from its "unknown" cell.
  for jaxpr_eqn in jaxpr.eqns:
    for var in it.chain(jaxpr_eqn.invars, jaxpr_eqn.outvars):
      env.write(var, cell_type.unknown(var.aval))

  # Seed the work queue with the equations adjacent to the outvars, invars
  # and constvars.
  seed_eqns = set()
  for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars):
    seed_eqns.update(neighbors_of(var))
  queue = collections.deque(seed_eqns)

  while queue:
    eqn = queue.popleft()
    old_incells = safe_map(env.read, eqn.invars)
    old_outcells = safe_map(env.read, eqn.outvars)
    rule = rules[eqn.primitive]
    call_jaxpr, params = jax_core.extract_call_jaxpr(eqn.primitive, eqn.params)
    subfuns = ([
        lu.wrap_init(
            functools.partial(propagate, cell_type, rules, call_jaxpr, ()))
    ] if call_jaxpr else [])
    new_incells, new_outcells, subenv = rule(subfuns + old_incells,
                                             old_outcells, **params)
    if subenv:
      env.write_subenv(eqn, subenv)
    new_incells = safe_map(env.write, eqn.invars, new_incells)
    new_outcells = safe_map(env.write, eqn.outvars, new_outcells)
    update_queue_state(queue, eqn, neighbors_of, old_incells, old_outcells,
                       new_incells, new_outcells)
  return env
def propagate(cell_type: Type[Cell],
              rules: Dict[jax_core.Primitive, PropagationRule],
              jaxpr: pe.Jaxpr,
              constcells: List[Cell],
              incells: List[Cell],
              outcells: List[Cell],
              reducer: Callable[[Environment, Equation, State, State],
                                State] = identity_reducer,
              initial_state: State = None) -> Tuple[Environment, State]:
  """Propagates cells in a Jaxpr using a set of rules.

  Args:
    cell_type: used to instantiate literals into cells
    rules: maps JAX primitives to propagation rule functions
    jaxpr: used to construct the propagation graph
    constcells: used to populate the Jaxpr's constvars
    incells: used to populate the Jaxpr's invars
    outcells: used to populate the Jaxpr's outvars
    reducer: An optional callable used to reduce over the state at each
      equation in the Jaxpr. `reducer` takes in `(env, eqn, state, new_state)`
      as arguments and should return an updated state. The `new_state` value
      is provided by each equation.
    initial_state: The initial `state` value used in the reducer

  Returns:
    A pair `(env, state)`. `env` is the Jaxpr environment after propagation
    has terminated. `state` is the result of folding `reducer` over the
    equations, in `jaxpr.eqns` order, starting from `initial_state`.

  Raises:
    KeyError: if an equation's primitive has no rule in `rules` (or, for call
      primitives, in `default_call_rules` either).
  """
  env = Environment(cell_type, jaxpr)
  safe_map(env.write, jaxpr.constvars, constcells)
  safe_map(env.write, jaxpr.invars, incells)
  safe_map(env.write, jaxpr.outvars, outcells)

  eqns = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)
  get_neighbor_eqns = construct_graph_representation(eqns)

  # Every variable an equation touches starts from its "unknown" cell.
  for eqn in jaxpr.eqns:
    for var in it.chain(eqn.invars, eqn.outvars):
      env.write(var, cell_type.unknown(var.aval))

  # Initialize propagation queue with equations neighboring constvars, invars,
  # and outvars.
  out_eqns = set()
  for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars):
    out_eqns.update(get_neighbor_eqns(var))
  queue = collections.deque(out_eqns)

  while queue:
    eqn = queue.popleft()
    incells = safe_map(env.read, eqn.invars)
    outcells = safe_map(env.read, eqn.outvars)
    call_jaxpr, params = jax_core.extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
      subfuns = [
          lu.wrap_init(
              functools.partial(propagate, cell_type, rules, call_jaxpr, (),
                                initial_state=initial_state,
                                reducer=reducer))
      ]
      if eqn.primitive in rules:
        rule = rules[eqn.primitive]
      else:
        rule = default_call_rules.get(eqn.primitive)
        if rule is None:
          # Fail clearly instead of calling `None` further down.
          raise KeyError(
              f'No propagation rule registered for call primitive '
              f'{eqn.primitive}.')
    else:
      subfuns = []
      rule = rules[eqn.primitive]
    new_incells, new_outcells, eqn_state = rule(subfuns + incells, outcells,
                                                **params)
    env.write_state(eqn, eqn_state)
    new_incells = safe_map(env.write, eqn.invars, new_incells)
    new_outcells = safe_map(env.write, eqn.outvars, new_outcells)
    update_queue_state(queue, eqn, get_neighbor_eqns, incells, outcells,
                       new_incells, new_outcells)

  # Fold the per-equation states (last one written per equation) in order.
  state = initial_state
  for eqn in eqns:
    state = reducer(env, eqn, state, env.read_state(eqn))
  return env, state
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in,
                  cotangents_in):
  """Transposes `jaxpr`, propagating cotangents from its outvars to its invars.

  The equations are walked in reverse order. Each equation's output
  cotangents are read, its transpose rule runs, and the resulting input
  cotangents are accumulated.

  Args:
    jaxpr: the jaxpr to transpose.
    reduce_axes: axis names. An incoming cotangent is `psum`-reduced over the
      names that appear in its named shape but not in the target variable's
      aval named shape.
    consts: values for `jaxpr.constvars`.
    primals_in: values for `jaxpr.invars`. Entries for which
      `is_undefined_primal` is true are not recorded as primals.
    cotangents_in: cotangents for `jaxpr.outvars`.

  Returns:
    A list of cotangents for `jaxpr.invars`. `Zero(v.aval)` is used for any
    invar that received no cotangent.
  """
  # Fast path: all-zero output cotangents give all-zero input cotangents.
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    if ct is None or type(v) is Literal:
      return
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    # Sum out named axes the cotangent carries but the variable does not.
    axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                           if axis_name in core.get_aval(ct).named_shape
                           and axis_name not in v.aval.named_shape)
    if axes_to_reduce:
      ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
    # Accumulate: a variable may receive cotangents from several consumers.
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if config.jax_enable_checks:
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(
          v.aval, ct_aval).strip_weak_type().strip_named_shape()
      assert v.aval.strip_weak_type().strip_named_shape(
      ) == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    # `pop`: reading consumes the accumulated cotangent for `v`.
    return ct_env.pop(v, Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      # Variables without a recorded primal are reported as undefined.
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  # forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  ct_env: Dict[Any, Any] = {}
  map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
  # Reverse order: every consumer of a variable is transposed before the
  # equation that produced it.
  for eqn in jaxpr.eqns[::-1]:
    # FIXME: Some invars correspond to tangents
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    with source_info_util.user_context(eqn.source_info.traceback):
      if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
        # Call/map transposes receive the sub-jaxpr and the output avals.
        cts_in_avals = [v.aval for v in eqn.outvars]
        call_jaxpr, params = core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        cts_out = get_primitive_transpose(
            eqn.primitive)(params, call_jaxpr, invals, cts_in, cts_in_avals,
                           reduce_axes)
      elif eqn.primitive in reducing_transposes:
        cts_out = reducing_transposes[eqn.primitive](reduce_axes, cts_in,
                                                     *invals, **eqn.params)
      else:
        cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                         **eqn.params)
    # A transpose rule may return the bare `Zero` class for "all zero".
    cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    # FIXME: Some invars correspond to primals!
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
def submerge_consts(jaxpr, consts, invals=None):
  """Replace constvars with literals in jaxpr and its sub-jaxprs.

  Equations whose inputs are all known (literals or constants) are evaluated
  eagerly with `primitive.bind`, and their outputs become constants too.
  Other equations get their known inputs substituted. Call/map primitives
  get their `call_jaxpr` rewritten recursively, and known inputs are dropped
  from the call's invars.

  Args:
    jaxpr: the `Jaxpr` to rewrite.
    consts: values for `jaxpr.constvars`, in order.
    invals: only set for recursive calls on a `call_jaxpr`: the (already
      substituted) invars of the enclosing call equation. `Var` entries stay
      inputs of the sub-jaxpr. Any other entry (a `Literal` or a raw value)
      becomes a known constant for the matching sub-jaxpr invar.

  Returns:
    A new `Jaxpr` with no constvars. Outvars whose values are known are
    emitted as `Literal_`s.
  """
  # TODO(j-towns): check that consts are in jax.core.literalable_types
  consts = dict(zip(jaxpr.constvars, consts))
  if invals is not None:
    # We're in a call_jaxpr
    new_jaxpr_invars = []
    for var, val in zip(jaxpr.invars, invals):
      if isinstance(val, Var):
        new_jaxpr_invars.append(var)
      else:
        # The enclosing equation may pass a Literal. Store its raw value so
        # the folding path below binds values, not Literal objects.
        consts[var] = val.val if isinstance(val, Literal) else val
  else:
    new_jaxpr_invars = jaxpr.invars
  new_eqns = []
  for eqn in jaxpr.eqns:
    if all(isinstance(var, Literal) or var in consts for var in eqn.invars):
      # Perform constant folding if all inputs to an eqn are known
      in_vals = [
          var.val if isinstance(var, Literal) else consts[var]
          for var in eqn.invars
      ]
      call_jaxpr, params = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
      if call_jaxpr:
        subfuns = [lu.wrap_init(partial(jc.eval_jaxpr, call_jaxpr, ()))]
      else:
        subfuns = []
      ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
      if eqn.primitive.multiple_results:
        for outvar, out in zip(eqn.outvars, ans):
          consts[outvar] = out
      else:
        outvar, = eqn.outvars
        consts[outvar] = ans
    else:
      new_invars = [
          consts[var] if (isinstance(var, Var) and var in consts) else var
          for var in eqn.invars
      ]
      new_params = dict(eqn.params)
      if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
        # Known inputs are submerged into the sub-jaxpr and dropped here.
        # NOTE(review): params aligned with the invars (if the primitive has
        # any, e.g. per-input axis specs) are not filtered alongside
        # `new_invars` — confirm this is safe for map primitives.
        new_params['call_jaxpr'] = submerge_consts(eqn.params['call_jaxpr'],
                                                   [], new_invars)
        new_invars = [var for var in new_invars if isinstance(var, Var)]
      else:
        new_invars = [
            var if isinstance(var, (Var, Literal)) else Literal_(var)
            for var in new_invars
        ]
      new_eqns.append(
          JaxprEqn(invars=new_invars,
                   outvars=eqn.outvars,
                   primitive=eqn.primitive,
                   params=new_params,
                   source_info=eqn.source_info))
  # Outvars that were constvars or were constant-folded are no longer defined
  # by any equation, so emit their known values as literals.
  new_outvars = [
      Literal_(consts[var]) if isinstance(var, Var) and var in consts else var
      for var in jaxpr.outvars
  ]
  return Jaxpr([], new_jaxpr_invars, new_outvars, new_eqns)