예제 #1
0
파일: unzip.py 프로젝트: yli96/probability
def unzip_to_init_apply_subjaxprs(master, settings, keys, pvals):
    """Function transformation that returns init/apply jaxprs.

    This is a generator-style transformation (presumably wrapped by a
    `linear_util` transformation decorator outside this view). It first
    yields `(in_tracers, {})` and receives the wrapped function's outputs.
    It then yields a single `(success, payload)` result.

    Args:
      master: Trace master handed to `UnzipTrace` and to
        `trace_util.new_dynamic_context`.
      settings: Unzip settings. They are passed to `UnzipContext`, and a
        truthy `settings.block` disables the init/apply split.
      keys: Per-input key values, paired positionally with `pvals` and
        passed to `trace.new_arg`. Input tracers with a truthy `.key` feed
        the init jaxpr; the remaining inputs feed the apply jaxpr.
      pvals: Per-input partial values used to create the input tracers.

    Yields:
      First `(in_tracers, {})`. Then `(success, payload)`, where:
        * if `success` is False (the split is blocked, or the variables do
          not cut the graph), `payload` is
          `(jaxpr, (out_pvals, out_keys, consts, env))` for the whole
          computation;
        * otherwise `payload` is
          `((init_jaxpr, init_consts, init_env),
            (apply_jaxpr, apply_consts, apply_env),
            (var_pvals, out_pvals),
            (variable_names, variable_tree, out_keys))`.

    Raises:
      ValueError: if two distinct variable recipes share a name, or if a
        variable has no name (`None`).
    """
    trace = UnzipTrace(master, jax_core.cur_sublevel())
    # One input tracer per (pval, key) pair. Passing both sequences to
    # `safe_map` (instead of a `zip`) checks that their lengths agree rather
    # than silently dropping trailing inputs.
    in_tracers = safe_map(trace.new_arg, pvals, keys)
    key_tracers = [t for t in in_tracers if t.key]
    abstract_tracers = [t for t in in_tracers if not t.key]

    # Run the wrapped function on the input tracers to get output tracers.
    context = UnzipContext(settings)
    with trace_util.new_dynamic_context(master, context):
        ans = yield in_tracers, {}
    out_tracers = safe_map(trace.full_raise,
                           safe_map(jax_core.full_lower, ans))
    out_pvals = [t.pval for t in out_tracers]

    all_tracers = jax_util.toposort(out_tracers)
    variable_tracers = [t for t in all_tracers if t.variable_recipe]
    if settings.block:
        success = False
    else:
        # Probe whether the variable tracers define a cut of the computation
        # graph. Temporarily rebind them as lambda inputs and try to build a
        # jaxpr from them (plus the non-key inputs) to the outputs. A
        # `VariableError` from `_tracers_to_jaxpr` signals that they do not.
        # The recipes are captured before the `try` so that `finally` can
        # always restore them.
        old_recipes = [t.recipe for t in variable_tracers]
        try:
            for t in variable_tracers:
                t.recipe = pe.LambdaBinding()
            _tracers_to_jaxpr(variable_tracers + abstract_tracers, out_tracers)
        except VariableError:
            success = False
        else:
            success = True
        finally:
            # Always restore the original recipes: this was only a probe.
            for t, old_recipe in safe_zip(variable_tracers, old_recipes):
                t.recipe = old_recipe

    if not success:
        # Fall back to a single jaxpr over all inputs.
        jaxpr, consts, env = _tracers_to_jaxpr(in_tracers, out_tracers)
        out_keys = [t.is_key() for t in out_tracers]
        yield success, (jaxpr, (out_pvals, out_keys, consts, env))
        return

    # Collect each distinct variable recipe by name. Reusing a name for a
    # different recipe object is an error.
    variable_recipes = {}
    for t in all_tracers:
        recipe = t.variable_recipe
        if not recipe:
            continue
        name = recipe.name
        if name in variable_recipes and variable_recipes[name] is not recipe:
            raise ValueError(
                'Cannot use duplicate variable name: {}'.format(name))
        variable_recipes[name] = recipe
    # Validate names up front, before building any jaxpr or rebinding any
    # tracer recipes below.
    if None in variable_recipes:
        raise ValueError('Must provide name for variable.')

    variables = {
        name: (recipe.in_tracers, recipe.out_tracers)
        for name, recipe in variable_recipes.items()
    }
    variable_names, variable_io = jax_util.unzip2(variables.items())
    var_in_tracers, var_out_tracers = jax_util.unzip2(variable_io)
    flat_var_in_tracers, variable_tree = tree_util.tree_flatten(var_in_tracers)
    var_pvals = [t.pval for t in flat_var_in_tracers]
    flat_var_out_tracers, _ = tree_util.tree_flatten(var_out_tracers)
    # init: key inputs -> variable inputs.
    init_jaxpr, init_consts, init_env = _tracers_to_jaxpr(
        key_tracers, flat_var_in_tracers)
    # apply: (variable outputs + non-key inputs) -> outputs. The variable
    # outputs are rebound as lambda inputs so the apply jaxpr starts there.
    for t in flat_var_out_tracers:
        t.recipe = pe.LambdaBinding()
    apply_jaxpr, apply_consts, apply_env = _tracers_to_jaxpr(
        flat_var_out_tracers + abstract_tracers, out_tracers)
    out_keys = [t.is_key() for t in out_tracers]
    yield success, ((init_jaxpr, init_consts, init_env),
                    (apply_jaxpr, apply_consts, apply_env),
                    (var_pvals, out_pvals),
                    (variable_names, variable_tree, out_keys))
예제 #2
0
파일: unzip.py 프로젝트: yli96/probability
 def new_arg(self, pval, key):
     """Create the `UnzipTracer` for one input of the traced function.

     Args:
       pval: Partial value for the input.
       key: Key value forwarded to the `UnzipTracer` constructor.

     Returns:
       An `UnzipTracer` bound to this trace. It is built with a fresh
       `pe.LambdaBinding()`, presumably its recipe, marking it as a
       lambda-bound input.
     """
     binding = pe.LambdaBinding()
     return UnzipTracer(self, pval, binding, key)