예제 #1
0
파일: unzip.py 프로젝트: yli96/probability
def _tracers_to_jaxpr(in_tracers, out_tracers):
    """Constructs a Jaxpr given tracers for inputs and outputs.

    Copied from jax.interpreters.partial_eval.tracers_to_jaxpr but modified to
    raise a VariableError when unknown in_tracers are found, rather than the
    default AssertionError.

    Args:
      in_tracers: the tracers that were created for the function inputs.
      out_tracers: the tracers that were output by the function.

    Returns:
      a triple of a `Jaxpr`, a sequence of constant values corresponding to
      the `constvars` in the returned Jaxpr, and a sequence of environment
      values. The vars for the environment values have been pre-pended to the
      Jaxpr's `invars`.

    Raises:
      VariableError: if a lambda-bound tracer that is not one of `in_tracers`
        is reachable from `out_tracers`.
      TypeError: if a tracer carries a recipe of an unrecognized kind.
    """
    newvar = jax_core.gensym(None)
    t_to_var = {}

    def getvar(t):
        # Each tracer (keyed by identity) is mapped to exactly one fresh var.
        var = t_to_var.get(id(t))
        if var is None:
            var = newvar(t.pval.get_aval())
            t_to_var[id(t)] = var
        return var

    sorted_tracers = jax_util.toposort(out_tracers)
    invars = safe_map(getvar, in_tracers)
    # Identity set so each lambda-binding membership test is O(1) instead of a
    # linear scan over `in_tracers`. `in_tracers` keeps every tracer alive for
    # the duration of this call, so the ids are unique.
    in_tracer_ids = {id(t) for t in in_tracers}
    eqns = []
    env = {}
    consts = {}
    const_to_var = {}

    def getconstvar(c):
        # Constants are deduplicated by identity, so one shared value object
        # gets a single constvar.
        var = const_to_var.get(id(c))
        if var is None:
            var = newvar(jax_core.get_aval(c))
            const_to_var[id(c)] = var
        return var

    processed_eqn_ids = set()
    for t in sorted_tracers:
        recipe = t.recipe
        if isinstance(recipe, pe.JaxprEqnRecipe):
            # Several tracers can share one equation recipe (same eqn_id);
            # emit each equation only once.
            if recipe.eqn_id not in processed_eqn_ids:
                eqns.append(pe.recipe_to_eqn(getvar, recipe))
                processed_eqn_ids.add(recipe.eqn_id)
        elif isinstance(recipe, pe.LambdaBinding):
            # Raise a catchable VariableError (not an assert) for inputs that
            # were not declared in `in_tracers`.
            if id(t) not in in_tracer_ids:
                raise VariableError('Found unknown input tracer.')
        elif isinstance(recipe, pe.FreeVar):
            env[getvar(t)] = recipe.val
        elif isinstance(recipe, pe.ConstVar):
            v = t_to_var[id(t)] = getconstvar(recipe.val)
            consts[v] = recipe.val
        elif isinstance(recipe, jax_core.Literal):
            t_to_var[id(t)] = recipe
        elif recipe is jax_core.unit:
            t_to_var[id(t)] = jax_core.unitvar
        else:
            raise TypeError(recipe)

    env_vars, env_vals = jax_util.unzip2(env.items())
    const_vars, const_vals = jax_util.unzip2(consts.items())
    # The env_vars are pre-pended to the invars.
    jaxpr = jax_core.Jaxpr(const_vars, list(it.chain(env_vars, invars)),
                           safe_map(getvar, out_tracers), eqns)
    return jaxpr, const_vals, env_vals
예제 #2
0
파일: unzip.py 프로젝트: yli96/probability
def _variables_define_cut(variable_tracers, abstract_tracers, out_tracers):
    """Returns whether the variable tracers define a cut of the graph.

    Temporarily rebinds every tracer in `variable_tracers` as a lambda-bound
    input and tries to build a jaxpr from `variable_tracers + abstract_tracers`
    to `out_tracers`. `_tracers_to_jaxpr` raises `VariableError` when it meets
    a lambda-bound tracer outside that input list, which means some output is
    reachable without passing through a variable. The original recipes are
    restored in every case, including when an unexpected exception propagates.
    """
    # Captured before the `try` so the `finally` clause can always rely on it.
    old_recipes = [t.recipe for t in variable_tracers]
    try:
        for t in variable_tracers:
            t.recipe = pe.LambdaBinding()
        _tracers_to_jaxpr(variable_tracers + abstract_tracers, out_tracers)
    except VariableError:
        return False
    finally:
        for t, old_recipe in safe_zip(variable_tracers, old_recipes):
            t.recipe = old_recipe
    return True


def unzip_to_init_apply_subjaxprs(master, settings, keys, pvals):
    """Function transformation that returns init/apply jaxprs.

    This is a generator. It first yields `(in_tracers, {})` and expects to be
    sent back the outputs of the traced function. It then yields one more
    value, a pair `(success, payload)`:

    * If `success` is False (splitting is blocked by `settings.block`, or the
      variable tracers do not define a cut of the computation graph),
      `payload` is `(jaxpr, (out_pvals, out_keys, consts, env))` describing a
      single jaxpr over all inputs.
    * Otherwise `payload` is
      `((init_jaxpr, init_consts, init_env),
        (apply_jaxpr, apply_consts, apply_env),
        (var_pvals, out_pvals),
        (variable_names, variable_tree, out_keys))`.
      The init jaxpr maps the key input tracers to the flattened variable
      input tracers. The apply jaxpr maps the flattened variable output
      tracers plus the non-key input tracers to the function outputs.

    Args:
      master: the master trace used for `UnzipTrace` and the dynamic context.
      settings: unzip settings; a truthy `settings.block` disables splitting.
      keys: values paired element-wise with `pvals` and passed to
        `trace.new_arg`; inputs whose tracer has a truthy `key` attribute are
        treated as key inputs.
      pvals: partial values for the function inputs.

    Raises:
      ValueError: if a variable has no name, or two distinct variables share
        the same name.
    """
    trace = UnzipTrace(master, jax_core.cur_sublevel())
    # One input UnzipTracer per (pval, key) pair.
    in_tracers = safe_map(lambda a: trace.new_arg(a[0], a[1]),
                          zip(pvals, keys))
    key_tracers = [t for t in in_tracers if t.key]
    abstract_tracers = [t for t in in_tracers if not t.key]
    # Hand the input tracers to the wrapped function and receive its outputs.
    context = UnzipContext(settings)
    with trace_util.new_dynamic_context(master, context):
        ans = yield in_tracers, {}
    out_tracers = safe_map(trace.full_raise, safe_map(jax_core.full_lower,
                                                      ans))
    out_pvals = [t.pval for t in out_tracers]

    all_tracers = jax_util.toposort(out_tracers)
    tracers_with_variables = [t for t in all_tracers if t.variable_recipe]
    success = (not settings.block and
               _variables_define_cut(tracers_with_variables, abstract_tracers,
                                     out_tracers))
    if not success:
        # Fall back to a single jaxpr over every input.
        jaxpr, consts, env = _tracers_to_jaxpr(in_tracers, out_tracers)
        out_keys = [t.is_key() for t in out_tracers]
        yield success, (jaxpr, (out_pvals, out_keys, consts, env))
        return

    # Collect one recipe per variable name. Names are validated here, before
    # any tracer recipe is rebound below, so a bad name fails fast and leaves
    # the graph untouched.
    variable_recipes = {}
    for t in all_tracers:
        recipe = t.variable_recipe
        if not recipe:
            continue
        name = recipe.name
        if name is None:
            raise ValueError('Must provide name for variable.')
        if name in variable_recipes and variable_recipes[name] is not recipe:
            raise ValueError(
                'Cannot use duplicate variable name: {}'.format(name))
        variable_recipes[name] = recipe

    variables = {
        name: (recipe.in_tracers, recipe.out_tracers)
        for name, recipe in variable_recipes.items()
    }
    variable_names, variable_io = jax_util.unzip2(variables.items())
    var_in_tracers, var_out_tracers = jax_util.unzip2(variable_io)
    flat_var_in_tracers, variable_tree = tree_util.tree_flatten(var_in_tracers)
    var_pvals = [t.pval for t in flat_var_in_tracers]
    flat_var_out_tracers, _ = tree_util.tree_flatten(var_out_tracers)
    # init: key inputs -> variable inputs.
    init_jaxpr, init_consts, init_env = _tracers_to_jaxpr(
        key_tracers, flat_var_in_tracers)
    # apply: variable outputs become lambda-bound inputs. Unlike the cut test
    # above, this rebinding is not undone.
    for t in flat_var_out_tracers:
        t.recipe = pe.LambdaBinding()
    apply_jaxpr, apply_consts, apply_env = _tracers_to_jaxpr(
        flat_var_out_tracers + abstract_tracers, out_tracers)
    out_keys = [t.is_key() for t in out_tracers]
    yield success, ((init_jaxpr, init_consts,
                     init_env), (apply_jaxpr, apply_consts, apply_env),
                    (var_pvals, out_pvals), (variable_names, variable_tree,
                                             out_keys))