def unzip_to_init_apply_subjaxprs(master, settings, keys, pvals):
  """Function transformation that returns init/apply jaxprs.

  This is a two-stage generator. It first yields `(in_tracers, {})` so the
  wrapped function can be traced on the input tracers; the function's outputs
  are sent back in, and it then yields `(success, aux)`.

  When `settings.block` is set, or when the variable tracers do not define a
  cut of the computation graph, `success` is False and a single jaxpr from all
  inputs to the outputs is produced. Otherwise the computation is split into
  an init jaxpr (key inputs -> flattened variable inputs) and an apply jaxpr
  (flattened variable outputs + non-key inputs -> outputs).

  Args:
    master: trace master used to build the `UnzipTrace` and dynamic context.
    settings: unzip settings; a truthy `settings.block` disables splitting.
    keys: per-input values passed to `trace.new_arg`; inputs whose tracer has
      a truthy `key` attribute are treated as key inputs.
    pvals: per-input partial values, in the same order as `keys`.

  Yields:
    First `(in_tracers, {})`. Then, on failure to split,
    `(False, (jaxpr, (out_pvals, out_keys, consts, env)))`; on success,
    `(True, ((init_jaxpr, init_consts, init_env),
    (apply_jaxpr, apply_consts, apply_env), (var_pvals, out_pvals),
    (variable_names, variable_tree, out_keys)))`.

  Raises:
    ValueError: when splitting, if a variable has no name or two distinct
      variables share a name.
  """
  trace = UnzipTrace(master, jax_core.cur_sublevel())
  # Passing the two sequences straight to `safe_map` keeps its length check;
  # zipping them first would silently drop trailing inputs on a mismatch.
  in_tracers = safe_map(trace.new_arg, pvals, keys)
  key_tracers = [t for t in in_tracers if t.key]
  abstract_tracers = [t for t in in_tracers if not t.key]

  # Hand the input tracers to the wrapped function and receive its outputs.
  context = UnzipContext(settings)
  with trace_util.new_dynamic_context(master, context):
    ans = yield in_tracers, {}
  out_tracers = safe_map(trace.full_raise, safe_map(jax_core.full_lower, ans))
  out_pvals = [t.pval for t in out_tracers]

  all_tracers = jax_util.toposort(out_tracers)
  variable_tracers = [t for t in all_tracers if t.variable_recipe]
  # `and` short-circuits, so the cut test is skipped entirely when blocked.
  success = (not settings.block and
             _variables_form_cut(variable_tracers, abstract_tracers,
                                 out_tracers))
  if not success:
    # Fall back to one jaxpr covering the whole computation.
    jaxpr, consts, env = _tracers_to_jaxpr(in_tracers, out_tracers)
    out_keys = [t.is_key() for t in out_tracers]
    yield success, (jaxpr, (out_pvals, out_keys, consts, env))
    return

  # Validated before any tracer recipe is mutated below.
  variable_recipes = _collect_variable_recipes(all_tracers)
  variables = {
      name: (recipe.in_tracers, recipe.out_tracers)
      for name, recipe in variable_recipes.items()
  }
  variable_names, variable_tracers = jax_util.unzip2(variables.items())
  var_in_tracers, var_out_tracers = jax_util.unzip2(variable_tracers)
  flat_var_in_tracers, variable_tree = tree_util.tree_flatten(var_in_tracers)
  var_pvals = [t.pval for t in flat_var_in_tracers]
  flat_var_out_tracers, _ = tree_util.tree_flatten(var_out_tracers)

  # init: key inputs -> the values flowing into the variables.
  init_jaxpr, init_consts, init_env = _tracers_to_jaxpr(
      key_tracers, flat_var_in_tracers)
  # apply: rebind variable outputs as fresh inputs, then trace to the outputs.
  for t in flat_var_out_tracers:
    t.recipe = pe.LambdaBinding()
  apply_jaxpr, apply_consts, apply_env = _tracers_to_jaxpr(
      flat_var_out_tracers + abstract_tracers, out_tracers)
  out_keys = [t.is_key() for t in out_tracers]
  yield success, ((init_jaxpr, init_consts, init_env),
                  (apply_jaxpr, apply_consts, apply_env),
                  (var_pvals, out_pvals),
                  (variable_names, variable_tree, out_keys))


def _variables_form_cut(variable_tracers, abstract_tracers, out_tracers):
  """Returns whether the variable tracers define a cut of the graph.

  Every variable tracer is temporarily rebound as a fresh `pe.LambdaBinding`
  input, and a jaxpr from the variables plus the non-key inputs to the
  outputs is attempted. `_tracers_to_jaxpr` signals failure by raising
  `VariableError` (the only exception handled here); presumably this happens
  when an output still depends on a key input through a path that bypasses
  the variables. The original recipes are always restored, on success and on
  failure alike.
  """
  # Captured before the `try` so the `finally` can always restore them.
  old_recipes = [t.recipe for t in variable_tracers]
  try:
    for t in variable_tracers:
      t.recipe = pe.LambdaBinding()
    _tracers_to_jaxpr(variable_tracers + abstract_tracers, out_tracers)
  except VariableError:
    return False
  finally:
    for t, old_recipe in safe_zip(variable_tracers, old_recipes):
      t.recipe = old_recipe
  return True


def _collect_variable_recipes(all_tracers):
  """Maps each variable name to its variable recipe, in traversal order.

  Raises:
    ValueError: if a variable recipe has no name, or if two distinct recipes
      share a name.
  """
  variable_recipes = {}
  for t in all_tracers:
    recipe = t.variable_recipe
    if not recipe:
      continue
    name = recipe.name
    # Checked first so unnamed variables are not misreported as duplicates.
    if name is None:
      raise ValueError('Must provide name for variable.')
    if name in variable_recipes and variable_recipes[name] is not recipe:
      raise ValueError('Cannot use duplicate variable name: {}'.format(name))
    variable_recipes[name] = recipe
  return variable_recipes
def new_arg(self, pval, key):
  """Creates a new `UnzipTracer` input on this trace.

  The tracer receives a fresh `pe.LambdaBinding` recipe, marking it as a
  lambda-bound input, along with the given partial value and key marker.
  """
  binding = pe.LambdaBinding()
  return UnzipTracer(self, pval, binding, key)