def vjp_func(tangents):
  """Computes per-layer primal and tangent values from output tangents.

  Runs `aux_vjp` on `tangents`. Its first output is a mapping of tangents keyed
  by jaxpr variable (updated in place here); the remaining outputs are the
  tangents for `jaxpr.invars`, which are flattened and merged into that mapping.
  Then, for every equation in `layer_tags`, the primal values (looked up in
  `primals_dict`) and tangent values of the equation's inputs are read and each
  is split with the layer tag's `split_all_inputs`.

  Args:
    tangents: the cotangents passed to `aux_vjp`.

  Returns:
    A tuple with one dict per entry of `layer_tags`, with keys "outputs",
    "inputs", "params", "outputs_tangent", "inputs_tangent" and
    "params_tangent".
  """
  vjp_outputs = aux_vjp(tangents)
  tangents_dict = vjp_outputs[0]
  flat_input_tangents, _ = jax.tree_flatten(vjp_outputs[1:])
  tangents_dict.update(zip(jaxpr.invars, flat_input_tangents))

  read_primals = functools.partial(tgm.read_env, primals_dict)
  read_tangents = functools.partial(tgm.read_env, tangents_dict)

  def describe_layer(tag_eqn):
    """Reads and splits the primal and tangent inputs of one layer tag."""
    layer_tag = _unbox_layer_tag(tag_eqn)
    invars = tuple(tag_eqn.invars)
    outputs, inputs, params = layer_tag.split_all_inputs(
        jax_util.safe_map(read_primals, invars))
    outputs_t, inputs_t, params_t = layer_tag.split_all_inputs(
        jax_util.safe_map(read_tangents, invars))
    return {
        "outputs": outputs,
        "inputs": inputs,
        "params": params,
        "outputs_tangent": outputs_t,
        "inputs_tangent": inputs_t,
        "params_tangent": params_t,
    }

  return tuple(describe_layer(tag_eqn) for tag_eqn in layer_tags)
def forward_compute_losses(
    params_primals: Any,
) -> Sequence[Sequence[jnp.ndarray]]:
  """Re-evaluates `jaxpr` with new parameters and returns loss-tag inputs.

  `params_primals` is written into the enclosing `primals_` list (in place, at
  `params_index`) before the primals are flattened and bound to the jaxpr
  inputs. Equations are evaluated in order; when `num_losses` is not None,
  evaluation stops as soon as that many loss tags have been seen.

  Args:
    params_primals: the parameter values to substitute.

  Returns:
    For every loss-tag equation evaluated, a tuple of its input values.
  """
  primals_[params_index] = params_primals
  flat_args, _ = jax.tree_flatten(primals_)

  env = {}  # jaxpr variable -> value
  read = functools.partial(tgm.read_env, env)
  write = functools.partial(tgm.write_env, env)

  write(jax.core.unitvar, jax.core.unit)
  jax_util.safe_map(write, jaxpr.invars, flat_args)
  jax_util.safe_map(write, jaxpr.constvars, consts)

  loss_eqns = []
  for eqn in jaxpr.eqns:
    tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
    if isinstance(eqn.primitive, tags.LossTag):
      loss_eqns.append(eqn)
    if num_losses is not None and len(loss_eqns) == num_losses:
      break
  return tuple(tuple(map(read, loss_eqn.invars)) for loss_eqn in loss_eqns)
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Turns already-traced tracers into a `core.TypedJaxpr` plus its consts.

  Args:
    in_tracers: tracers for the traced function's inputs.
    out_tracers: tracers for the traced function's outputs; they are lowered,
      raised into `trace` and instantiated before conversion.
    trace: the trace the tracers belong to.
    instantiate: a single flag, replicated for every output tracer and passed
      to `pe.instantiate_const_at`.

  Returns:
    `(typed_jaxpr, consts)`. The jaxpr's constvars are converted into leading
    inputs, so its input avals are the const avals followed by the input avals.
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  instantiate = [instantiate] * len(out_tracers)
  out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
  out_tracers = safe_map(partial(pe.instantiate_const_at, trace), instantiate,
                         out_tracers)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  assert not env

  # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
  out_avals = safe_map(abstract_arrays.raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(
      abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)
  in_pvals = [t.pval for t in in_tracers]
  in_avals = tuple(
      safe_map(abstract_arrays.raise_to_shaped, unzip2(in_pvals)[0]))

  typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                const_avals + in_avals, out_avals)
  return typed_jaxpr, consts
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Turns already-traced tracers into a `core.ClosedJaxpr` plus its consts.

  Args:
    in_tracers: tracers for the traced function's inputs.
    out_tracers: tracers for the traced function's outputs.
    trace: the trace the tracers belong to.
    instantiate: flag applied to every output via `pe.instantiate_const_at`.

  Returns:
    `(closed_jaxpr, consts)`, where the jaxpr's constvars have been converted
    into leading inputs and the closed jaxpr itself holds no consts.
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  flags = [instantiate] * len(out_tracers)
  lowered = safe_map(core.full_lower, out_tracers)
  raised = safe_map(trace.full_raise, lowered)
  final_out_tracers = safe_map(partial(pe.instantiate_const_at, trace), flags,
                               raised)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, final_out_tracers)
  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  assert not env
  return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()), consts
def plant_function(main: jax_core.MainTrace, settings: HarvestSettings,
                   in_tree: Any, args: Iterable[Any]):
  """A function transformation that injects values in place of sows.

  Written as a two-stage generator: the first `yield` hands the wrapped
  function its (traced) arguments and receives its outputs; the second
  `yield` produces the final output values.

  Args:
    main: the main trace that `HarvestTrace` is created from.
    settings: harvest settings stored in the `PlantContext`.
    in_tree: tree structure of `(plants, args)`, used to unflatten `args`.
    args: flat values of `(plants, args)`.
  """
  trace = HarvestTrace(main, jax_core.cur_sublevel())
  # The flat inputs carry both the values to plant and the real arguments.
  plants, args = tree_util.tree_unflatten(in_tree, args)
  args = jax_util.safe_map(trace.pure, args)
  context = PlantContext(settings, plants)
  with trace_util.new_dynamic_context(main, context):
    ans = yield args, {}
    out_tracers = jax_util.safe_map(trace.full_raise, ans)
    del main
  # Unwrap the tracers back into plain values for the caller.
  yield [t.val for t in out_tracers]
def jaxpr_to_expressions(jaxpr: jax_core.Jaxpr) -> Tuple[Expr, ...]:
  """Converts a JAXpr into a tuple of output `JaxExpression`s.

  Constants and inputs of `jaxpr` become `JaxVar` patterns (named after the
  variable and carrying its aval's shape and dtype). Every equation is then
  turned into an expression over previously built expressions; equations that
  carry a call jaxpr are converted recursively into a `CallPrimitive`.

  Args:
    jaxpr: a `jax.core.Jaxpr` to be converted into a tuple of
      `JaxExpression`s.

  Returns:
    A tuple of `JaxExpression`s, one per entry of `jaxpr.outvars`.
  """
  # Expressions are keyed by the variable's string name.
  env = {}

  def read_env(var: jax_core.Var) -> Any:
    # Literals are never stored; wrap their value on demand.
    if isinstance(var, jax_core.Literal):
      return Literal(var.val)
    return env[str(var)]

  def write_env(var: jax_core.Var, val: Any) -> None:
    if isinstance(var, jax_core.Literal):
      return
    env[str(var)] = val

  def var_to_pattern(var: jax_core.Var):
    """Builds the `JaxVar` pattern that stands in for a jaxpr variable."""
    return JaxVar(str(var), var.aval.shape, var.aval.dtype)

  # Bind constvars first, then invars, each to its own pattern.
  for binders in (jaxpr.constvars, jaxpr.invars):
    jax_util.safe_map(write_env, binders,
                      jax_util.safe_map(var_to_pattern, binders))

  for eqn in jaxpr.eqns:
    operands = tuple(jax_util.safe_map(read_env, eqn.invars))
    call_jaxpr, params = jax_core.extract_call_jaxpr(
        eqn.primitive, eqn.params)
    if call_jaxpr:
      call_expression = BoundExpression(jaxpr_to_expressions(call_jaxpr), {})
      variable_names = tuple(map(str, call_jaxpr.invars))
      out = CallPrimitive(eqn.primitive, operands, call_expression,
                          Params(params), variable_names)
    else:
      out = primitive_to_expression(eqn.primitive)(operands, Params(params))
    if eqn.primitive.multiple_results:
      # Expose each result of a multi-output primitive as its own `Part`.
      out_parts = [Part(out, i) for i in range(len(eqn.outvars))]
      jax_util.safe_map(write_env, eqn.outvars, out_parts)
    else:
      write_env(eqn.outvars[0], out)
  return tuple(jax_util.safe_map(read_env, jaxpr.outvars))
def default_process_primitive(
    self, primitive: jax_core.Primitive, tracers: List['HarvestTracer'],
    params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
  """Binds `primitive` on the tracers' values and re-wraps the results.

  `sow_p` is not bound; it is delegated to the active dynamic context's
  `process_sow`, whose results are each wrapped with `self.pure`.

  Returns:
    A list of tracers for `sow_p` and multiple-result primitives, otherwise a
    single tracer.
  """
  context = trace_util.get_dynamic_context(self)
  vals = [tracer.val for tracer in tracers]
  if primitive is sow_p:
    return jax_util.safe_map(self.pure, context.process_sow(*vals, **params))
  result = primitive.bind(*vals, **params)
  if primitive.multiple_results:
    return jax_util.safe_map(self.pure, result)
  return self.pure(result)
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    argspecs: Sequence[ArgSpec],  # mix of sparse and dense pointers into spenv
    spenv: SparseEnv,
) -> Sequence[ArgSpec]:
  """Interprets `jaxpr` over `ArgSpec`s that point into `spenv`.

  Equations with at least one sparse input are dispatched to `sparse_rules`.
  All other equations are bound on dense data, and each result buffer is
  pushed into `spenv` and wrapped in a dense `ArgSpec`.

  Raises:
    NotImplementedError: if a sparse input reaches a primitive that has no
      entry in `sparse_rules`.
  """
  env: Dict[core.Var, ArgSpec] = {}

  def as_dense_spec(buf):
    """Pushes `buf` into `spenv` and returns a dense `ArgSpec` for it."""
    return ArgSpec(buf.shape, spenv.push(buf), None)

  def read(var: core.Var) -> Union[Array, ArgSpec]:
    # Literals are always dense; their value may be a Python scalar.
    if not isinstance(var, core.Literal):
      return env[var]
    return ArgSpec(np.shape(var.val), spenv.push(var.val), None)

  def write_buffer(var: core.Var, a: Array) -> None:
    if var is not core.dropvar:
      env[var] = as_dense_spec(a)

  def write(var: core.Var, a: ArgSpec) -> None:
    if var is not core.dropvar:
      env[var] = a

  # TODO: decide whether core.unitvar needs a binding here, e.g.
  # write_buffer(core.unitvar, core.unit).
  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, argspecs)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)
    if any(val.is_sparse() for val in invals):
      if prim not in sparse_rules:
        raise NotImplementedError(f"sparse rule for {prim}")
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        call_params = eqn.params.copy()
        fun = lu.wrap_init(
            core.jaxpr_as_fun(
                pe.ClosedJaxpr(call_params.pop('call_jaxpr'), ())))
        result = prim.bind(fun, *(val.data(spenv) for val in invals),
                           **call_params)
      else:
        result = prim.bind(*(val.data(spenv) for val in invals), **eqn.params)
      if not prim.multiple_results:
        result = [result]
      out = [as_dense_spec(buf) for buf in result]
    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
def _plant_cond_rule(trace, *tracers, branches, linear):
  """Injects the same values into every branch of a conditional.

  Each branch is wrapped with `plant` using the current context's settings and
  plants. The wrapped branches are re-traced with common consts, and
  `cond_p` is re-bound with those consts prepended to the operands.

  Args:
    trace: the active harvest trace.
    *tracers: the branch index tracer followed by the operand tracers.
    branches: the original branch jaxprs.
    linear: per-operand linearity flags of the original `cond`.

  Returns:
    The outputs of the new `cond`, wrapped with `trace.pure`.
  """
  index_tracer, ops_tracers = tracers[0], tracers[1:]
  index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
                                           (index_tracer, ops_tracers))
  ops_avals = tree_util.tree_map(lambda x: x.aval, ops_tracers)
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  plant_settings = dict(tag=settings.tag,
                        allowlist=settings.allowlist,
                        blocklist=settings.blocklist,
                        exclusive=settings.exclusive)
  # Probe every branch for its sown values and validate them together.
  branch_metadatas = tuple(
      _get_harvest_metadata(branch, settings, *ops_tracers)
      for branch in branches)
  _check_branch_metadata(branch_metadatas)
  plants = context.plants
  branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
  planted_branches = tuple(
      functools.partial(plant(f, **plant_settings), plants)
      for f in branch_funs)
  in_tree = tree_util.tree_structure(ops_avals)
  new_branch_jaxprs, consts, _ = (
      lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
          planted_branches, in_tree, ops_avals, lax.cond_p.name))
  # The hoisted consts are new leading operands and are marked non-linear.
  out = lax.cond_p.bind(index_val,
                        *(tuple(consts) + ops_vals),
                        branches=tuple(new_branch_jaxprs),
                        linear=(False, ) * len(tuple(consts) + linear))
  return jax_util.safe_map(trace.pure, out)
def _reap_cond_rule(trace, *tracers, branches, linear):
  """Reaps each path of the `cond`.

  Every branch is wrapped with `call_and_reap`, so that it also returns its
  reaped values. The new `cond` is bound, and each reaped value is then
  re-sown outside the conditional.

  Args:
    trace: the active harvest trace.
    *tracers: the branch index tracer followed by the operand tracers.
    branches: the original branch jaxprs.
    linear: per-operand linearity flags of the original `cond`.

  Returns:
    The original outputs of the `cond`, without the reaped values.
  """
  index_tracer, ops_tracers = tracers[0], tracers[1:]
  index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
                                           (index_tracer, ops_tracers))
  _, ops_avals = tree_util.tree_map(lambda x: x.aval,
                                    (index_tracer, ops_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  reap_settings = dict(tag=settings.tag,
                       allowlist=settings.allowlist,
                       blocklist=settings.blocklist,
                       exclusive=settings.exclusive)
  branch_metadatas = tuple(
      _get_harvest_metadata(branch, settings, *ops_tracers)
      for branch in branches)
  _check_branch_metadata(branch_metadatas)
  branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
  reaped_branches = tuple(
      call_and_reap(f, **reap_settings) for f in branch_funs)
  in_tree = tree_util.tree_structure(ops_avals)
  new_branch_jaxprs, consts, out_trees = (
      lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
          reaped_branches, in_tree, ops_avals, lax.cond_p.name))
  out = lax.cond_p.bind(index_val,
                        *(tuple(consts) + ops_vals),
                        branches=tuple(new_branch_jaxprs),
                        linear=(False, ) * len(tuple(consts) + linear))
  out = jax_util.safe_map(trace.pure, out)
  # The first branch's output tree is used to split `(outputs, reaps)`;
  # presumably `_check_branch_metadata` guarantees all branches agree.
  out, reaps = tree_util.tree_unflatten(out_trees[0], out)
  for k, v in reaps.items():
    sow(v, name=k, tag=settings.tag, mode=branch_metadatas[0][k]['mode'])
  return out
def process_primitive(self, primitive: core.Primitive,
                      tracers: Sequence[TensorFlowTracer],
                      params) -> TensorFlowTracer:
  """Applies the TF implementation of `primitive` to the tracers' values.

  Unless `core.skip_checks` is set, each result's aval is compared against
  `primitive.abstract_eval`. For multiple-result primitives this includes the
  number of outputs.

  Returns:
    A `TensorFlowTracer`, or a list of them for multiple-result primitives.

  Raises:
    AssertionError: if checks are enabled and the implementation returned the
      wrong number of outputs or an output with an unexpected aval.
  """
  impl = self.get_primitive_impl(primitive)
  args_tf: Sequence[TfValOrUnit] = [t.val for t in tracers]
  # impl takes core.unit and returns core.unit when needed.
  val_out: TfValOrUnit = impl(*args_tf, **params)
  if primitive.multiple_results:
    out = util.safe_map(functools.partial(TensorFlowTracer, self), val_out)  # type: ignore
  else:
    out = TensorFlowTracer(self, val_out)

  # Check that the impl rule returned a value of expected shape and dtype
  if not core.skip_checks:
    expected_out_aval: core.AbstractValue = primitive.abstract_eval(
        *[t.aval for t in tracers], **params)
    if primitive.multiple_results:
      # zip() silently ignores extra or missing outputs, so check the count.
      assert len(out) == len(expected_out_aval), (  # type: ignore
          f"{primitive}: got {len(out)} outputs; "  # type: ignore
          f"expected {len(expected_out_aval)}")  # type: ignore
      for o, expected_aval in zip(out, expected_out_aval):  # type: ignore
        assert o.aval == expected_aval, (
            f"{primitive}: out.aval = {o.aval}; expected {expected_aval}")
    else:
      assert out.aval == expected_out_aval, (  # type: ignore
          f"{primitive}: out.aval = {out.aval}; expected {expected_out_aval}"
      )  # type: ignore
  return out  # type: ignore
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
  """Interprets `jaxpr` over `SparsifyValue`s stored in `spenv`.

  Equations with at least one sparse input are dispatched to `sparse_rules`.
  All other equations are bound on dense data (`spenv.data`), and each result
  is registered as a dense value with `spenv.dense`. Outputs bound to a
  `core.DropVar` are not registered.

  Returns:
    The `SparsifyValue`s for `jaxpr.outvars`.
  """
  env: Dict[core.Var, SparsifyValue] = {}

  def read(var: core.Var) -> Union[Array, SparsifyValue]:
    # all literals are dense
    if isinstance(var, core.Literal):
      return spenv.dense(var.val)
    else:
      return env[var]

  def write_buffer(var: core.Var, a: Array) -> None:
    if isinstance(var, core.DropVar):
      return
    env[var] = spenv.dense(a)

  def write(var: core.Var, a: SparsifyValue) -> None:
    if isinstance(var, core.DropVar):
      return
    # Dropped outputs are passed as None; anything else must be a real value.
    assert a is not None
    env[var] = a

  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)

    if any(val.is_sparse() for val in invals):
      if prim not in sparse_rules:
        _raise_unimplemented_primitive(prim)
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        fun = lu.wrap_init(core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
        out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params)
      else:
        out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
      out_bufs = out_bufs if prim.multiple_results else [out_bufs]
      out = []
      for buf, outvar in safe_zip(out_bufs, eqn.outvars):
        # Skip registering buffers whose outputs are dropped.
        if isinstance(outvar, core.DropVar):
          out.append(None)
        else:
          out.append(spenv.dense(buf))
    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
def evaluate_eqn(eqn, in_values, write_func):
  """Evaluates a single JAX equation and writes its outputs via `write_func`.

  If the equation carries a call jaxpr (as `xla_call` does), that jaxpr is
  wrapped as a function with no consts and passed as the leading argument of
  `bind`.

  Args:
    eqn: the equation to evaluate.
    in_values: values for `eqn.invars`, in order.
    write_func: callable `(var, value)` used to store each output.

  Returns:
    The raw result of `eqn.primitive.bind`.
  """
  bind_args = list(in_values)
  call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
  if call_jaxpr:
    callee = functools.partial(jax.core.eval_jaxpr, call_jaxpr, ())
    bind_args.insert(0, jax.core.lu.wrap_init(callee))
  ans = eqn.primitive.bind(*bind_args, **params)
  if not eqn.primitive.multiple_results:
    write_func(eqn.outvars[0], ans)
  else:
    jax_util.safe_map(write_func, eqn.outvars, ans)
  return ans
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
  """Reaps the body of a while loop to get the reaps of the final iteration.

  The loop carry is extended with the body's reaped values, which are
  initialized to zeros. Each iteration overwrites them, so after the loop
  they hold the values reaped in the last iteration. These are re-sown
  outside the loop.

  Args:
    trace: the active harvest trace.
    *tracers: cond consts, then body consts, then initial carry tracers.
    cond_jaxpr: the original loop condition.
    body_jaxpr: the original loop body.
    cond_nconsts: number of leading tracers that are cond consts.
    body_nconsts: number of following tracers that are body consts.

  Returns:
    The original loop outputs (the final carry without the reaps).

  Raises:
    ValueError: if any value sown in the body does not use 'clobber' mode.
  """
  cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
      tracers, [cond_nconsts, body_nconsts])
  _, init_avals = tree_util.tree_map(lambda x: x.aval,
                                     (body_const_tracers, init_tracers))
  cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
      lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                        *(body_const_tracers + init_tracers))
  # Only the last iteration's reaps survive, so only 'clobber' is meaningful.
  for k, meta in body_metadata.items():
    mode = meta['mode']
    if mode != 'clobber':
      raise ValueError(
          f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
  reap_avals = {k: v['aval'] for k, v in body_metadata.items()}

  cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
  body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
  reap_settings = dict(tag=settings.tag,
                       allowlist=settings.allowlist,
                       blocklist=settings.blocklist,
                       exclusive=settings.exclusive)

  # The new carry is `(carry, reaps)`; the condition ignores the reaps.
  def new_cond(carry, _):
    return cond_fun(*(cond_const_vals + carry))

  # The body recomputes the reaps from scratch, discarding the previous ones.
  def new_body(carry, _):
    carry, reaps = call_and_reap(body_fun,
                                 **reap_settings)(*(body_const_vals + carry))
    return (carry, reaps)

  new_in_avals, new_in_tree = tree_util.tree_flatten((init_avals, reap_avals))
  new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_cond, new_in_tree, tuple(new_in_avals))
  new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, new_in_tree, tuple(new_in_avals))
  # Zero-filled placeholders give the reap slots of the carry their types.
  dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype),
                                       reap_avals)
  new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals))
  out = lax.while_p.bind(*(cond_consts + body_consts + new_in_vals),
                         cond_nconsts=len(cond_consts),
                         body_nconsts=len(body_consts),
                         cond_jaxpr=new_cond_jaxpr,
                         body_jaxpr=new_body_jaxpr)
  out = jax_util.safe_map(trace.pure, out)
  out, reaps = tree_util.tree_unflatten(out_tree, out)
  for k, v in reaps.items():
    sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode'])
  return out
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  """Emits a `while_p` whose condition is `self.cond_func` on the carried state.

  Args:
    scope: the loop scope whose `_mutable_state` holds the carried fields.
    carried_state_names: names of the scope fields carried through the loop.
    carried_tree: pytree structure of the carried state.
    init_vals: flat initial values of the carried state.
    body_closed_jaxpr: the traced loop body.
    body_const_vals: constants consumed by `body_closed_jaxpr`.

  Returns:
    The outputs of `while_p.bind`.

  Raises:
    ValueError: if the conditional function modifies a carried scope field.
    TypeError: if the conditional function does not return a boolean scalar.
  """
  # Trace the conditional function. cond_func takes 0 arguments, but
  # for lax.while we need a conditional function that takes the
  # carried_state_names. _initial_style_jaxpr will start its own trace and
  # will create tracers for all the carried state. We must put these values
  # in the scope._mutable_state before we trace the conditional
  # function.
  def cond_func_wrapped(*args):
    assert len(args) == len(carried_state_names)
    for ms, init_ms in zip(carried_state_names, args):
      scope._mutable_state[ms] = init_ms
    res = self.cond_func()
    # Conditional function is not allowed to modify the scope state
    for ms, init_ms in zip(carried_state_names, args):
      if scope._mutable_state[ms] is not init_ms:
        raise ValueError(
            f"Conditional function modifies scope.{ms} field.")
    return res

  init_avals = safe_map(_BodyTracer.abstractify, init_vals)
  cond_jaxpr, cond_consts, cond_tree = (
      lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
                                            carried_tree,
                                            tuple(init_avals)))
  # TODO: share these checks with lax_control_flow.while
  if not tree_util.treedef_is_leaf(cond_tree):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got pytree {cond_tree}.")
  # safe_map returns a list, and any non-empty list is truthy; `all` is needed
  # for a non-boolean or non-scalar result to actually be rejected.
  if not all(safe_map(core.typecompat, cond_jaxpr.out_avals,
                      [core.ShapedArray((), np.bool_)])):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got output type(s) "
        f"{cond_jaxpr.out_avals}.")

  return lax_control_flow.while_p.bind(
      *itertools.chain(cond_consts, body_const_vals, init_vals),
      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
      body_nconsts=len(body_const_vals), body_jaxpr=body_closed_jaxpr)
def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]:
  """Runs one loop step and keeps the old carry wherever `pred_b` is False.

  The body is interpreted on the current carry. For each carry element, the
  new value is selected where the (broadcast) predicate is true and the old
  value is kept elsewhere. The condition is then re-evaluated on the selected
  carry.

  Returns:
    `(next_pred_b, *selected_carry)`.
  """
  stepped = _interpret_jaxpr(body_jaxpr, *body_consts, *carry)

  def keep_where_pred(new_value, old_value):
    # Map the predicate's dimensions onto the leading dimensions of the carry
    # element (assuming `_broadcast_in_dim` follows lax.broadcast_in_dim).
    mask = _broadcast_in_dim(pred_b, new_value.shape,
                             list(range(len(pred_b.shape))))
    return tf.where(mask, new_value, old_value)

  selected_carry = list(util.safe_map(keep_where_pred, stepped, carry))
  (next_pred_b,) = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry)
  return (next_pred_b, *selected_carry)
def reap_function(main: jax_core.MainTrace, settings: HarvestSettings,
                  return_metadata: bool, args: Iterable[Any]):
  """A function transformation that returns reap values.

  Written as a two-stage generator: the first `yield` hands the wrapped
  function its traced arguments and receives its outputs; the second `yield`
  produces the final result.

  Args:
    main: the main trace that `HarvestTrace` is created from.
    settings: harvest settings stored in the `ReapContext`.
    return_metadata: whether to also return each reap's metadata.
    args: flat input values.

  Yields:
    First the traced inputs, then `(out_values, reap_values)`, or
    `(out_values, reap_values, reap_metadata)` if `return_metadata` is set.
  """
  trace = HarvestTrace(main, jax_core.cur_sublevel())
  in_tracers = jax_util.safe_map(trace.pure, args)
  # Values sown during the call are collected into `context.reaps`.
  context = ReapContext(settings, {})
  with trace_util.new_dynamic_context(main, context):
    ans = yield in_tracers, {}
    out_tracers = jax_util.safe_map(trace.full_raise, ans)
    reap_tracers = tree_util.tree_map(lambda x: trace.full_raise(x.value),
                                      context.reaps)
    reap_metadata = tree_util.tree_map(lambda x: x.metadata, context.reaps)
    del main
  out_values, reap_values = tree_util.tree_map(lambda x: x.val,
                                               (out_tracers, reap_tracers))
  if return_metadata:
    out = (out_values, reap_values, reap_metadata)
  else:
    out = (out_values, reap_values)
  yield out
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_typed_jaxpr, body_const_vals):
  """Emits a `cond_p` that runs the body when `self.pred` holds.

  The false branch is a traced identity over the carried state, so the carry
  passes through unchanged when the predicate is false.

  Returns:
    The outputs of `cond_p.bind`.
  """
  # Simulate a pass-through false branch
  init_avals = safe_map(_BodyTracer.abstractify, init_vals)
  false_body_typed_jaxpr, false_body_const_vals, _ = (
      lax_control_flow._initial_style_jaxpr(lambda *args: args,
                                            carried_tree,
                                            tuple(init_avals)))
  # Operand layout (presumably the older `true_jaxpr`/`false_jaxpr` cond_p
  # signature): predicate, true-branch consts + carry, false-branch consts +
  # carry.
  return lax_control_flow.cond_p.bind(
      *itertools.chain([self.pred], body_const_vals,
                       init_vals, false_body_const_vals, init_vals),
      true_jaxpr=body_typed_jaxpr, false_jaxpr=false_body_typed_jaxpr)
def jaxpr_const_maker(*args, **kwargs):
  """Traces the enclosing `fun` on `(args, kwargs)`.

  Arguments are flattened, abstracted with `pv_like`, and the flattened
  function is traced with `pe.trace_to_jaxpr`.

  Returns:
    `(jaxpr, consts)` from the trace.
  """
  wrapped_fun = lu.wrap_init(fun)
  flat_inputs, input_tree = tree_util.tree_flatten((args, kwargs))
  flat_fun, _ = api_util.flatten_fun(wrapped_fun, input_tree)
  partial_vals = safe_map(pv_like, flat_inputs)
  jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, partial_vals)
  return jaxpr, consts
def _unravel_array_into_pytree(pytree, axis, arr, is_leaf=None):
  """Splits `arr` along `axis` into pieces shaped like the leaves of `pytree`.

  Dimension `axis` of `arr` is cut into consecutive chunks of size
  `np.size(leaf)`. Each chunk is reshaped so that this single dimension is
  replaced by the leaf's full shape.

  Args:
    pytree: a pytree whose structure and leaf shapes define the output.
    axis: the axis of `arr` to unravel; negative values count from the end.
    arr: the array to split.
    is_leaf: optional leaf predicate forwarded to `tree_flatten`.

  Returns:
    A pytree with the structure of `pytree` whose leaves are pieces of `arr`.
  """
  leaves, treedef = tree_flatten(pytree, is_leaf=is_leaf)
  axis %= arr.ndim
  before, after = arr.shape[:axis], arr.shape[axis + 1:]
  # The last leaf needs no split point: it takes whatever remains.
  split_points = np.cumsum(safe_map(np.size, leaves[:-1]))
  pieces = arr.split(split_points, axis)
  reshaped_pieces = [
      piece.reshape(before + np.shape(leaf) + after)
      for piece, leaf in zip(pieces, leaves)
  ]
  return tree_unflatten(treedef, reshaped_pieces)
def _tfval_add_unit(vals: Sequence[TfValOrUnit],
                    avals: Sequence[core.AbstractValue]) -> Sequence[TfValOrUnit]:
  """Turn regular TfVals into TfValOrUnit, based on expected abstract values.

  This function is sometimes called with a mix of core.unit and tf.nan in
  places of units.

  Args:
    vals: the values, paired positionally with `avals`.
    avals: the expected abstract values.

  Returns:
    `vals` with every entry whose aval is `core.abstract_unit` replaced by
    `core.unit`.
  """
  def add_unit(v: TfValOrUnit, aval: core.AbstractValue) -> TfValOrUnit:
    # Sanity check: a unit slot must hold core.unit or NaN; others a TfVal.
    if not core.skip_checks:
      assert ((v is core.unit or tf.math.is_nan(v))
              if aval is core.abstract_unit else _is_tfval(v))
    return core.unit if aval is core.abstract_unit else v

  return util.safe_map(add_unit, vals, avals)
def wrapped_auto_registered(*args):
  """Re-evaluates `graph` on `args`, applying the registered tagging functions.

  Equations of `graph.jaxpr` are evaluated in order. After each equation, every
  output variable with an entry in `matches` has its value replaced by the
  result of that entry's tagging function. Orphan parameters are wrapped with
  `tags.register_generic` before evaluation. When `compute_only_loss_tags` is
  set, only `loss_output_vars` are returned and evaluation stops once
  `num_losses` loss tags have been evaluated.
  """
  flat_args, _ = jax.tree_flatten(args)
  env = {}  # jaxpr variable -> value
  read = functools.partial(read_env, env)
  write = functools.partial(write_env, env)

  def apply_tagging(var):
    """Replaces the value of `var` with its tagged version, if it is matched."""
    match = matches.get(var)
    if match is None:
      return
    inv_map, tagging_func = match
    # Entries with string keys are not looked up in the environment.
    var_map = {k: v for k, v in inv_map.items() if not isinstance(k, str)}
    env[var] = tagging_func(inv_map, jax.tree_map(read, var_map))

  write(jax.core.unitvar, jax.core.unit)
  jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
  jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)

  for param_var in orphan_params:
    write(param_var, tags.register_generic(read(param_var)))

  if compute_only_loss_tags:
    output_vars = loss_output_vars
    out_tree = jax.tree_structure(loss_output_vars)
  else:
    output_vars, out_tree = graph.jaxpr.outvars, graph.out_tree

  losses_evaluated = 0
  for eqn in graph.jaxpr.eqns:
    evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
    for outvar in eqn.outvars:
      apply_tagging(outvar)
    if isinstance(eqn.primitive, tags.LossTag):
      losses_evaluated += 1
    if compute_only_loss_tags and num_losses == losses_evaluated:
      break

  return jax.tree_unflatten(out_tree, jax_util.safe_map(read, output_vars))
def forward():
  """Evaluates `jaxpr` on `func_args` up to its last loss tag.

  Evaluation stops right after the `len(loss_tags)`-th loss-tag equation.

  Returns:
    The values of `layer_input_vars`.

  Raises:
    ValueError: if fewer than `len(loss_tags)` loss tags were encountered.
  """
  env = {}  # jaxpr variable -> value
  read = functools.partial(tgm.read_env, env)
  write = functools.partial(tgm.write_env, env)

  write(jax.core.unitvar, jax.core.unit)
  flat_args, _ = jax.tree_flatten(func_args)
  jax_util.safe_map(write, jaxpr.invars, flat_args)
  jax_util.safe_map(write, jaxpr.constvars, consts)

  seen_losses = 0
  for eqn in jaxpr.eqns:
    tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
    if not isinstance(eqn.primitive, tags.LossTag):
      continue
    seen_losses += 1
    if seen_losses == len(loss_tags):
      break
  if seen_losses != len(loss_tags):
    raise ValueError("This should be unreachable.")
  return jax_util.safe_map(read, layer_input_vars)
def forward_aux(aux): own_func_args = func_args # Mapping from variable -> value env = dict() read = functools.partial(tgm.read_env, env) def write(var, val): if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)): val = val + aux[var] if var in aux else val env[var] = val # Bind args and consts to environment write(jax.core.unitvar, jax.core.unit) jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0]) jax_util.safe_map(write, jaxpr.constvars, consts) # Loop through equations and evaluate primitives using `bind` num_losses_passed = 0 losses_inputs_values = [] losses_kwargs_values = [] for eqn in jaxpr.eqns: input_values = jax_util.safe_map(read, eqn.invars) tgm.evaluate_eqn(eqn, input_values, write) if isinstance(eqn.primitive, tags.LossTag): loss = eqn.primitive.loss(*input_values, weight=eqn.params["weight"]) losses_inputs_values.append(loss.inputs) losses_kwargs_values.append( dict(targets=loss.targets, weight=eqn.params["weight"])) num_losses_passed += 1 if num_losses_passed == len(loss_tags): break if num_losses_passed != len(loss_tags): raise ValueError("This should be unreachable.") # Read the inputs to the loss functions, but also return the target values return tuple(losses_inputs_values), tuple(losses_kwargs_values)
def eval_jaxpr_with_state(jaxpr: jax_core.Jaxpr, rules: Rules,
                          consts: Sequence[Value], state: Value,
                          *args: Value) -> Tuple[List[Value], Value]:
  """Interprets a JAXpr and manages an input state with primitive rules.

  The implementation follows `jax.core.eval_jaxpr` closely; the main
  differences are:
  1. Rather than always calling `primitive.bind`, `eval_jaxpr_with_state`
     looks up a rule for the primitive in the provided `rules` dictionary
     first.
  2. A `state` value is provided that is threaded through the execution of
     the JAXpr and whose final value is returned as an additional output.
  3. Equations carrying a call jaxpr are handled by the rule registered in
     `_effect_handler_call_rules` for that primitive, falling back to
     `default_call_interpreter_rule`.

  Args:
    jaxpr: The JAXpr to be interpreted.
    rules: A `dict` that maps JAX primitives to functions that take in
      `(state, *args)` and return `(output, new_state)`.
    consts: A list of constant values corresponding to the JAXpr's constvars.
    state: The initial state for the interpreter.
    *args: A list of values that correspond to the JAXpr's invars.

  Returns:
    A list of outputs from the JAXpr and the final state.
  """
  env = Environment()

  jax_util.safe_map(env.write, jaxpr.constvars, consts)
  jax_util.safe_map(env.write, jaxpr.invars, args)

  for eqn in jaxpr.eqns:
    invals = jax_util.safe_map(env.read, eqn.invars)
    call_jaxpr, params = jax_core.extract_call_jaxpr(
        eqn.primitive, eqn.params)
    if call_jaxpr:
      # Call primitives recurse into their sub-jaxpr, threading `state`.
      call_rule = _effect_handler_call_rules.get(
          eqn.primitive,
          functools.partial(default_call_interpreter_rule, eqn.primitive))
      ans, state = call_rule(rules, state, invals, call_jaxpr, **params)
    elif eqn.primitive in rules:
      ans, state = rules[eqn.primitive](state, *invals, **params)
    else:
      # No rule: evaluate normally and leave `state` untouched.
      ans = eqn.primitive.bind(*invals, **params)
    if eqn.primitive.multiple_results:
      jax_util.safe_map(env.write, eqn.outvars, ans)
    else:
      env.write(eqn.outvars[0], ans)
  return jax_util.safe_map(env.read, jaxpr.outvars), state
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2): """Raises TypeError if (tree1, avals1) does not match (tree2, avals2). Corresponding `tree` and `avals` must match in the sense that the number of leaves in `tree` must be equal to the length of `avals`. `what` will be prepended to details of the mismatch in TypeError. """ if tree1 != tree2: msg = ("{} must have same type structure, got {} and {}.") raise TypeError(msg.format(what, tree1, tree2)) if not all(safe_map(typematch, avals1, avals2)): msg = ("{} must have identical types, " "got {} and {}.") raise TypeError(msg.format(what, tree_unflatten(tree1, avals1), tree_unflatten(tree2, avals2)))
def _get_harvest_metadata(closed_jaxpr, settings, *args):
  """Probes a jaxpr for metadata like its sown values.

  `closed_jaxpr` is traced abstractly (only the shaped avals of `args` are
  used) under a reap transformation. The metadata it records is read back via
  `_reap_metadata_wrapper`'s auxiliary output.

  Returns:
    The metadata produced by `_reap_metadata_wrapper`.
  """
  def shaped_aval(x):
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))

  probed_fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
  with jax_core.new_main(HarvestTrace) as main:
    # Positional fields: tag, blocklist, allowlist, then a flag forced to True
    # (presumably `exclusive` -- confirm against HarvestSettings).
    probe_settings = HarvestSettings(settings.tag, settings.blocklist,
                                     settings.allowlist, True)
    probed_fun = reap_function(probed_fun, main, probe_settings, True)
    probed_fun, aux = _reap_metadata_wrapper(probed_fun)
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_fun, out_tree = api_util.flatten_fun_nokwargs(probed_fun, in_tree)
    in_avals = jax_util.safe_map(shaped_aval, flat_args)
    pe.trace_to_jaxpr_final(flat_fun, in_avals)
    metadata = aux()
    # The output tree is forced but not used.
    out_tree()
  return metadata
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  """Emits a `cond_p` choosing between a pass-through branch and the body.

  Branch 0 returns the carried state unchanged and branch 1 is
  `body_closed_jaxpr`; `self.index` selects between them. Both branches take
  the body consts followed by the carried values as operands.

  Returns:
    The outputs of `cond_p.bind`.
  """
  # Simulate a pass-through false branch
  in_vals, in_tree = tree_util.tree_flatten(
      (body_const_vals, tree_util.tree_unflatten(carried_tree, init_vals)))
  in_avals = safe_map(_BodyTracer.abstractify, in_vals)
  # `args[1]` is the (unflattened) carried state, assuming
  # `_initial_style_jaxpr` calls the function with the `in_tree` structure.
  pass_through_closed_jaxpr, pass_through_const_vals, _ = (
      lax_control_flow._initial_style_jaxpr(
          lambda *args: args[1],
          in_tree,
          tuple(in_avals)))
  assert len(pass_through_const_vals) == 0
  args = list(itertools.chain(body_const_vals, init_vals))
  return lax_control_flow.cond_p.bind(
      self.index, *args,
      branches=(pass_through_closed_jaxpr, body_closed_jaxpr),
      linear=(False,) * len(args))
def _get_jax_objects(function, args, kwargs):
  """Traces `function` on `(args, kwargs)` into a jaxpr.

  The arguments are flattened and abstracted with `_get_partial_value`, and
  the flattened function is traced with `trace_to_jaxpr`.

  Returns:
    `(jaxpr, constants)` from the trace.
  """
  wrapped = j_linear_util.wrap_init(function)
  flat_arguments, input_tree = j_tree_util.tree_flatten((args, kwargs))
  flat_function, _ = j_api_util.flatten_fun(wrapped, input_tree)
  abstract_arguments = j_util.safe_map(_get_partial_value, flat_arguments)
  jaxpr, _, constants = ji_partial_eval.trace_to_jaxpr(
      flat_function, abstract_arguments)
  return jaxpr, constants
def _tfval_add_unit(vals: Sequence[TfValOrUnit],
                    avals: Sequence[core.AbstractValue]) -> Sequence[TfValOrUnit]:
  """Turn regular TfVals into TfValOrUnit, based on expected abstract values.

  A value whose aval is `core.abstract_unit` may arrive as `core.unit`, as an
  EagerTensor holding NaN (the concrete TF stand-in for units), or as a
  graph Tensor whose NaN value has been abstracted away. All of these become
  `core.unit`; every other value is returned unchanged.
  """
  def add_unit(v: TfValOrUnit, aval: core.AbstractValue):
    is_unit_slot = aval is core.abstract_unit
    if not core.skip_checks:
      if not is_unit_slot:
        assert _is_tfval(v)
      elif v is not core.unit:
        assert isinstance(v, tf.Tensor)
        # Only EagerTensors have a device, so only they can be checked for NaN.
        if v.device:
          assert tf.math.is_nan(v)
    return core.unit if is_unit_slot else v

  return util.safe_map(add_unit, vals, avals)