Exemplo n.º 1
0
def inline_calls(jaxpr):
    """Returns a copy of `jaxpr` with its call equations inlined.

    Every equation for which `jc.extract_call_jaxpr` yields a sub-jaxpr is
    replaced, recursively, by the equations of that sub-jaxpr. The sub-jaxpr's
    invars are renamed to the call's operands and its outvars to the call's
    outvars; all other equations are kept.

    Args:
      jaxpr: the `Jaxpr` to transform.

    Returns:
      A new `Jaxpr` with the same constvars and invars and no call equations.
      When an inlined callee returned one of its inputs, a literal, or the
      same variable twice, the caller-side variable it was bound to is never
      defined by an equation; every later use of it (including the returned
      outvars) is replaced by the value that actually stands for it.

    Raises:
      NotImplementedError: if an equation carries a sub-jaxpr but its
        primitive is neither `jc.call_p` nor `xla.xla_call_p`.

    NOTE(review): sub-jaxpr constvars are not remapped, and inlining the same
    sub-jaxpr object twice reuses its internal variables; both are assumed not
    to happen for the supported call primitives -- confirm.
    """
    new_eqns = []
    # Caller-side variables that no inlined equation defines, mapped to the
    # value that stands in for them.
    aliases = {}

    def resolve(v):
        # Check for literals first: older JAX `Literal`s are not hashable.
        while not isinstance(v, Literal) and v in aliases:
            v = aliases[v]
        return v

    def check_supported(eqn):
        if eqn.primitive not in {jc.call_p, xla.xla_call_p}:
            raise NotImplementedError

    def inline_call(jaxpr, invars, outvars):
        inmap = dict(zip(jaxpr.invars, invars))
        outmap = {
            v: o
            for v, o in zip(jaxpr.outvars, outvars)
            if not isinstance(v, Literal)
        }

        def rename(v):
            if isinstance(v, Literal):
                return v
            # A callee result is written straight to the caller's outvar (see
            # `new_outvars` below), so later reads of it must use that name.
            return resolve(inmap.get(v, outmap.get(v, v)))

        for eqn in jaxpr.eqns:
            new_invars = [rename(v) for v in eqn.invars]
            new_outvars = [outmap.get(v, v) for v in eqn.outvars]
            call_jaxpr, _ = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
            if call_jaxpr:
                check_supported(eqn)
                inline_call(call_jaxpr, new_invars, new_outvars)
            else:
                new_eqns.append(
                    JaxprEqn(new_invars, new_outvars, eqn.primitive,
                             eqn.params, eqn.source_info))

        # Caller outvars not written by any equation above (pass-through
        # inputs, literals, duplicated results) are aliased to their source.
        for sub_out, outer_out in zip(jaxpr.outvars, outvars):
            source = rename(sub_out)
            if source is not outer_out:
                aliases[outer_out] = source

    for eqn in jaxpr.eqns:
        invars = [resolve(v) for v in eqn.invars]
        call_jaxpr, _ = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
        if call_jaxpr:
            check_supported(eqn)
            inline_call(call_jaxpr, invars, eqn.outvars)
        elif all(new is old for new, old in zip(invars, eqn.invars)):
            # Equations unaffected by aliasing are kept as the original objects.
            new_eqns.append(eqn)
        else:
            new_eqns.append(
                JaxprEqn(invars, eqn.outvars, eqn.primitive, eqn.params,
                         eqn.source_info))

    outvars = [resolve(v) for v in jaxpr.outvars]
    return Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, new_eqns)
Exemplo n.º 2
0
def lazy_eval_jaxpr(jaxpr, consts, *args):
    """Builds a graph of `LazyArray`s for `jaxpr` instead of evaluating it.

    Every equation outvar is first bound to a `LazyArray`. A second pass then
    links each of those arrays (via `set_eqn`) to a copy of its defining
    equation. The copy's operands are the values bound to the equation's
    inputs: constants, arguments, literal values or other `LazyArray`s.

    NOTE(review): the `map(...)` calls are relied on for their side effects, so
    this assumes the module-level `map` is an eager, list-returning map (e.g.
    JAX's `safe_map`) rather than the lazy builtin -- confirm.

    Args:
      jaxpr: the jaxpr to evaluate lazily.
      consts: values for `jaxpr.constvars`.
      *args: values for `jaxpr.invars`.

    Returns:
      The values bound to `jaxpr.outvars`.

    Raises:
      NotImplementedError: if any equation carries a sub-jaxpr.
    """
    env = {}

    def read(var):
        return var.val if type(var) in {jc.Literal, Literal_} else env[var]

    def write(var, value):
        env[var] = value

    write(jc.unitvar, jc.unit)
    map(write, jaxpr.constvars, consts)
    map(write, jaxpr.invars, args)

    # Pass 1: reject call primitives and give every outvar a LazyArray.
    for eqn in jaxpr.eqns:
        sub_jaxpr, _ = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
        if sub_jaxpr:
            raise NotImplementedError
        map(write, eqn.outvars, map(LazyArray, eqn.outvars))

    # Pass 2: attach to each LazyArray the equation that produces it.
    for eqn in jaxpr.eqns:
        in_values = map(read, eqn.invars)
        out_arrays = map(read, eqn.outvars)
        linked_eqn = jc.JaxprEqn(in_values, out_arrays, eqn.primitive,
                                 eqn.params, eqn.source_info)
        map(lambda arr: arr.set_eqn(linked_eqn), out_arrays)

    return map(read, jaxpr.outvars)
def jaxpr_to_expressions(jaxpr: jax_core.Jaxpr) -> Tuple[Expr, ...]:
    """Converts a JAXpr into a tuple of output `JaxExpression`s.

  Constvars and invars become `JaxVar` patterns named after the variable,
  literals become `Literal` expressions, and each equation becomes an
  expression over its operands. Call primitives are represented as a
  `CallPrimitive` wrapping the recursively converted sub-jaxpr.

  Args:
    jaxpr: a `jax.core.Jaxpr` to be converted into a tuple of `JaxExpression`s.
  Returns:
    A tuple of `JaxExpression`s, one per entry of `jaxpr.outvars`.
  """
    # Expressions keyed by the string name of the JAX variable they stand for.
    env = {}

    def read_env(var: jax_core.Var) -> Any:
        if isinstance(var, jax_core.Literal):
            return Literal(var.val)
        return env[str(var)]

    def write_env(var: jax_core.Var, val: Any) -> None:
        if isinstance(var, jax_core.Literal):
            return
        env[str(var)] = val

    def bind_patterns(variables) -> None:
        # Each input variable is represented by a `JaxVar` pattern carrying
        # its name, shape and dtype.
        patterns = jax_util.safe_map(
            lambda var: JaxVar(str(var), var.aval.shape, var.aval.dtype),
            variables)
        jax_util.safe_map(write_env, variables, patterns)

    bind_patterns(jaxpr.constvars)
    bind_patterns(jaxpr.invars)

    for eqn in jaxpr.eqns:
        operands = tuple(jax_util.safe_map(read_env, eqn.invars))

        call_jaxpr, params = jax_core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        if call_jaxpr:
            # The sub-jaxpr is kept as a nested expression tree; its invar
            # names tell `CallPrimitive` how operands map onto it.
            call_expression = BoundExpression(jaxpr_to_expressions(call_jaxpr),
                                              {})
            variable_names = tuple(map(str, call_jaxpr.invars))
            out = CallPrimitive(eqn.primitive, operands, call_expression,
                                Params(params), variable_names)
        else:
            out = primitive_to_expression(eqn.primitive)(operands,
                                                         Params(params))
        if eqn.primitive.multiple_results:
            # Output i of a multi-result primitive is `Part(out, i)`.
            out_parts = [Part(out, i) for i in range(len(eqn.outvars))]
            jax_util.safe_map(write_env, eqn.outvars, out_parts)
        else:
            write_env(eqn.outvars[0], out)
    return tuple(jax_util.safe_map(read_env, jaxpr.outvars))
Exemplo n.º 4
0
def eval_jaxpr_with_state(jaxpr: jax_core.Jaxpr, rules: Rules,
                          consts: Sequence[Value], state: Value,
                          *args: Value) -> Tuple[List[Value], Value]:
    """Interprets a JAXpr, threading a state value through primitive rules.

  This closely follows `jax.core.eval_jaxpr`, with two differences:
  1. A primitive found in `rules` is evaluated with its rule instead of
     `primitive.bind`. Call primitives use the rule registered in
     `_effect_handler_call_rules`, falling back to
     `default_call_interpreter_rule`.
  2. `state` is passed into every rule, updated from each rule's result, and
     returned alongside the JAXpr's outputs.

  Args:
    jaxpr: The JAXpr to interpret.
    rules: A `dict` mapping JAX primitives to functions of `(state, *args)`
      that return `(output, new_state)`.
    consts: Values for the JAXpr's constvars.
    state: The initial interpreter state.
    *args: Values for the JAXpr's invars.

  Returns:
    The list of JAXpr outputs and the final state.
  """
    env = Environment()
    for variables, values in ((jaxpr.constvars, consts), (jaxpr.invars, args)):
        jax_util.safe_map(env.write, variables, values)

    for eqn in jaxpr.eqns:
        prim = eqn.primitive
        operands = jax_util.safe_map(env.read, eqn.invars)
        sub_jaxpr, sub_params = jax_core.extract_call_jaxpr(prim, eqn.params)
        if sub_jaxpr:
            fallback_rule = functools.partial(default_call_interpreter_rule,
                                              prim)
            call_rule = _effect_handler_call_rules.get(prim, fallback_rule)
            result, state = call_rule(rules, state, operands, sub_jaxpr,
                                      **sub_params)
        elif prim in rules:
            result, state = rules[prim](state, *operands, **sub_params)
        else:
            # No rule: evaluate normally; the state passes through untouched.
            result = prim.bind(*operands, **sub_params)

        if not prim.multiple_results:
            env.write(eqn.outvars[0], result)
        else:
            jax_util.safe_map(env.write, eqn.outvars, result)

    return jax_util.safe_map(env.read, jaxpr.outvars), state
Exemplo n.º 5
0
def eval_jaxpr_with_kwargs(jaxpr: jax_core.Jaxpr, consts: Iterable[Any], *args,
                           **kwargs):
    """Evals a jaxpr while passing kwargs into registered primitives.

  Works like `jax.core.eval_jaxpr`, except for primitives registered in
  `kwargs_rules`. These are evaluated with their rule and a `kwargs` argument
  that merges the equation's own `kwargs` param with the caller's `**kwargs`;
  on a key clash the caller's entry wins. Call sub-jaxprs are evaluated
  recursively with the same `**kwargs`.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: values for `jaxpr.constvars`.
    *args: values for `jaxpr.invars`.
    **kwargs: keyword arguments forwarded to `kwargs_rules` rules.
  Returns:
    A list of values for `jaxpr.outvars`.
  """
    # Use functools directly rather than the `partial` that `jax.core` happens
    # to import, which is not part of its API.
    import functools

    def read(v):
        if isinstance(v, jax_core.Literal):
            return v.val
        else:
            return env[v]

    def write(v, val):
        env[v] = val

    env = {}
    write(jax_core.unitvar, jax_core.unit)
    safe_map(write, jaxpr.constvars, consts)
    safe_map(write, jaxpr.invars, args)
    for eqn in jaxpr.eqns:
        in_vals = safe_map(read, eqn.invars)
        subjaxpr, params = jax_core.extract_call_jaxpr(eqn.primitive,
                                                       eqn.params)
        if subjaxpr:
            # Call primitives are bound with a function that evaluates the
            # sub-jaxpr, forwarding the same **kwargs.
            subfuns = [
                lu.wrap_init(
                    functools.partial(eval_jaxpr_with_kwargs, subjaxpr, (),
                                      **kwargs))
            ]
        else:
            subfuns = []
            # Copy so that popping 'kwargs' below cannot mutate eqn.params.
            params = dict(eqn.params)
        if eqn.primitive in kwargs_rules:
            new_kwargs = dict(params.pop('kwargs', {}), **kwargs)
            ans = kwargs_rules[eqn.primitive](*(subfuns + in_vals),
                                              kwargs=new_kwargs,
                                              **params)
        else:
            ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
        if eqn.primitive.multiple_results:
            safe_map(write, eqn.outvars, ans)
        else:
            write(eqn.outvars[0], ans)
    return safe_map(read, jaxpr.outvars)
Exemplo n.º 6
0
def propagate(cell_type: Type[Cell], rules: Dict[jax_core.Primitive,
                                                 PropagationRule],
              jaxpr: pe.Jaxpr, constcells: List[Cell], incells: List[Cell],
              outcells: List[Cell]) -> Environment:
    """Runs rule-driven cell propagation over a Jaxpr until its queue drains.

  Args:
    cell_type: used to instantiate literals into cells
    rules: maps JAX primitives to propagation rule functions
    jaxpr: used to construct the propagation graph
    constcells: cells written to the Jaxpr's constvars
    incells: cells written to the Jaxpr's invars
    outcells: cells written to the Jaxpr's outvars
  Returns:
    The Jaxpr environment once no equation is left in the propagation queue
  """
    env = Environment(cell_type, jaxpr)
    for variables, cells in ((jaxpr.constvars, constcells),
                             (jaxpr.invars, incells),
                             (jaxpr.outvars, outcells)):
        safe_map(env.write, variables, cells)

    eqns = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)
    get_neighbor_eqns = construct_graph_representation(eqns)

    # Every variable an equation touches is written an unknown cell.
    for jaxpr_eqn in jaxpr.eqns:
        for var in it.chain(jaxpr_eqn.invars, jaxpr_eqn.outvars):
            env.write(var, cell_type.unknown(var.aval))

    # Seed the work queue with the equations neighboring the Jaxpr's outvars,
    # invars and constvars.
    seed_eqns = {
        neighbor
        for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars)
        for neighbor in get_neighbor_eqns(var)
    }
    queue = collections.deque(seed_eqns)

    while queue:
        eqn = queue.popleft()
        old_incells = safe_map(env.read, eqn.invars)
        old_outcells = safe_map(env.read, eqn.outvars)

        rule = rules[eqn.primitive]
        sub_jaxpr, params = jax_core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        # Call primitives hand their rule a function that propagates through
        # the sub-jaxpr.
        subfuns = [
            lu.wrap_init(
                functools.partial(propagate, cell_type, rules, sub_jaxpr, ()))
        ] if sub_jaxpr else []
        updated_incells, updated_outcells, subenv = rule(
            subfuns + old_incells, old_outcells, **params)
        if subenv:
            env.write_subenv(eqn, subenv)

        updated_incells = safe_map(env.write, eqn.invars, updated_incells)
        updated_outcells = safe_map(env.write, eqn.outvars, updated_outcells)
        update_queue_state(queue, eqn, get_neighbor_eqns, old_incells,
                           old_outcells, updated_incells, updated_outcells)
    return env
Exemplo n.º 7
0
def propagate(cell_type: Type[Cell],
              rules: Dict[jax_core.Primitive, PropagationRule],
              jaxpr: pe.Jaxpr,
              constcells: List[Cell],
              incells: List[Cell],
              outcells: List[Cell],
              reducer: Callable[[Environment, Equation, State, State],
                                State] = identity_reducer,
              initial_state: State = None) -> Tuple[Environment, State]:
    """Propagates cells in a Jaxpr using a set of rules.

  Args:
    cell_type: used to instantiate literals into cells
    rules: maps JAX primitives to propagation rule functions. Call primitives
      missing from `rules` fall back to `default_call_rules`.
    jaxpr: used to construct the propagation graph
    constcells: used to populate the Jaxpr's constvars
    incells: used to populate the Jaxpr's invars
    outcells: used to populate the Jaxpr's outvars
    reducer: An optional callable used to reduce over the state at each
      equation in the Jaxpr. `reducer` takes in `(env, eqn, state, new_state)`
      as arguments and should return an updated state. The `new_state` value
      is provided by each equation.
    initial_state: The initial `state` value used in the reducer
  Returns:
    The Jaxpr environment after propagation has terminated, and the state
    obtained by folding `reducer` over every equation's recorded state.
  Raises:
    KeyError: if no propagation rule is available for a primitive that the
      propagation reaches.
  """
    env = Environment(cell_type, jaxpr)
    safe_map(env.write, jaxpr.constvars, constcells)
    safe_map(env.write, jaxpr.invars, incells)
    safe_map(env.write, jaxpr.outvars, outcells)

    eqns = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)

    get_neighbor_eqns = construct_graph_representation(eqns)
    # Every variable an equation touches is written an unknown cell.
    # NOTE(review): assumes `env.write` merges with the cells written above
    # rather than overwriting them -- confirm against `Environment`.
    for eqn in jaxpr.eqns:
        for var in it.chain(eqn.invars, eqn.outvars):
            env.write(var, cell_type.unknown(var.aval))

    # Initialize propagation queue with equations neighboring constvars, invars,
    # and outvars.
    out_eqns = set()
    for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars):
        out_eqns.update(get_neighbor_eqns(var))
    queue = collections.deque(out_eqns)
    while queue:
        eqn = queue.popleft()

        incells = safe_map(env.read, eqn.invars)
        outcells = safe_map(env.read, eqn.outvars)

        call_jaxpr, params = jax_core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        if call_jaxpr:
            # The rule receives a function that propagates through the
            # sub-jaxpr with the same rules and reducer.
            subfuns = [
                lu.wrap_init(
                    functools.partial(propagate,
                                      cell_type,
                                      rules,
                                      call_jaxpr, (),
                                      initial_state=initial_state,
                                      reducer=reducer))
            ]
            if eqn.primitive in rules:
                rule = rules[eqn.primitive]
            else:
                rule = default_call_rules.get(eqn.primitive)
            if rule is None:
                # Fail here with a clear message instead of later with
                # "'NoneType' object is not callable".
                raise KeyError(
                    f'No propagation rule for call primitive {eqn.primitive}')
        else:
            subfuns = []
            rule = rules[eqn.primitive]
        new_incells, new_outcells, eqn_state = rule(subfuns + incells,
                                                    outcells, **params)
        env.write_state(eqn, eqn_state)

        new_incells = safe_map(env.write, eqn.invars, new_incells)
        new_outcells = safe_map(env.write, eqn.outvars, new_outcells)

        update_queue_state(queue, eqn, get_neighbor_eqns, incells, outcells,
                           new_incells, new_outcells)
    # Fold the per-equation states, in equation order, into the final state.
    state = initial_state
    for eqn in eqns:
        state = reducer(env, eqn, state, env.read_state(eqn))
    return env, state
Exemplo n.º 8
0
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in,
                  cotangents_in):
    """Transposes `jaxpr`, propagating cotangents from its outputs to its inputs.

    Equations are visited in reverse order. For each one, the known primal
    inputs are read; variables with no recorded primal read as
    `UndefinedPrimal`. The output cotangents are passed to the primitive's
    transpose rule, and the resulting input cotangents are accumulated per
    variable with `add_tangents`.

    Args:
      jaxpr: the jaxpr to transpose.
      reduce_axes: named axes over which a cotangent is `psum`-reduced when it
        carries the axis in its named shape but the variable's aval does not.
      consts: values for `jaxpr.constvars`.
      primals_in: values for `jaxpr.invars`; entries for which
        `is_undefined_primal` holds are not recorded.
      cotangents_in: cotangents for `jaxpr.outvars`; `None` and `Zero`
        entries are skipped.

    Returns:
      The accumulated cotangent for each of `jaxpr.invars`, or `Zero` of the
      variable's aval when none was accumulated.

    NOTE(review): `map` is relied on for side effects below, so it is assumed
    to be an eager, list-returning map (e.g. JAX's `safe_map`) -- confirm.
    """
    # Fast path: all output cotangents are symbolic zeros, so are the inputs'.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    def write_cotangent(prim, v, ct):
        # Accumulates `ct` into `ct_env[v]`; `prim` only labels assertion
        # failures.
        # assert v not in primal_env
        assert ct is not Zero, (prim, v.aval
                                )  # check for an old harmless type error
        if ct is None or type(v) is Literal:
            return
        if type(ct) is Zero:
            # FIXME: This triggers a lot of failures!
            # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
            return
        # psum over the `reduce_axes` that the cotangent carries but `v` lacks.
        axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                               if axis_name in core.get_aval(ct).named_shape
                               and axis_name not in v.aval.named_shape)
        if axes_to_reduce:
            ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
        ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
        if config.jax_enable_checks:
            # Debug check: the accumulated cotangent's aval must join with
            # `v`'s aval (ignoring weak types and named shapes) to `v`'s aval.
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = core.lattice_join(
                v.aval, ct_aval).strip_weak_type().strip_named_shape()
            assert v.aval.strip_weak_type().strip_named_shape(
            ) == joined_aval, (prim, v.aval, ct_aval)

    def read_cotangent(v):
        # Removes the cotangent from `ct_env`; absent ones read as `Zero`.
        return ct_env.pop(v, Zero(v.aval))

    def read_primal(v):
        # Literals read as their value; unrecorded vars as `UndefinedPrimal`.
        if type(v) is Literal:
            return v.val
        else:
            return primal_env.get(v, UndefinedPrimal(v.aval))

    def write_primal(v, val):
        # Undefined primals are not stored, so `read_primal` falls back for them.
        if not is_undefined_primal(val):
            primal_env[v] = val

    primal_env: Dict[Any, Any] = {}
    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    #        forces primal_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    ct_env: Dict[Any, Any] = {}
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    # Walk the equations in reverse, turning output cotangents into input
    # cotangents with each primitive's transpose rule.
    for eqn in jaxpr.eqns[::-1]:
        # FIXME: Some invars correspond to tangents
        invals = map(read_primal, eqn.invars)
        if eqn.primitive.multiple_results:
            cts_in = map(read_cotangent, eqn.outvars)
        else:
            cts_in, = map(read_cotangent, eqn.outvars)
        with source_info_util.user_context(eqn.source_info.traceback):
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                # Call/map rules also receive the sub-jaxpr, the output avals
                # and `reduce_axes`.
                cts_in_avals = [v.aval for v in eqn.outvars]
                call_jaxpr, params = core.extract_call_jaxpr(
                    eqn.primitive, eqn.params)
                cts_out = get_primitive_transpose(
                    eqn.primitive)(params, call_jaxpr, invals, cts_in,
                                   cts_in_avals, reduce_axes)
            elif eqn.primitive in reducing_transposes:
                cts_out = reducing_transposes[eqn.primitive](reduce_axes,
                                                             cts_in, *invals,
                                                             **eqn.params)
            else:
                cts_out = get_primitive_transpose(eqn.primitive)(cts_in,
                                                                 *invals,
                                                                 **eqn.params)
        # A rule may return the `Zero` class itself to mean "all inputs zero".
        cts_out = [Zero(v.aval)
                   for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

    cotangents_out = map(read_cotangent, jaxpr.invars)
    return cotangents_out
Exemplo n.º 9
0
def submerge_consts(jaxpr, consts, invals=None):
    """
  Replace constvars with literals in jaxpr and its sub-jaxprs.

  Equations whose inputs are all known are evaluated eagerly with
  `primitive.bind` and dropped. Known inputs are literals, constvars, or
  results folded earlier. Remaining equations get their known inputs
  substituted as literals, and call/map sub-jaxprs are processed
  recursively. Outvars whose value became known are replaced by literals
  too, so no outvar refers to a variable that is no longer defined.

  Args:
    jaxpr: the jaxpr to transform.
    consts: values for `jaxpr.constvars`.
    invals: only set when recursing into the sub-jaxpr of a call/map
      equation. It holds that equation's operands, aligned with
      `jaxpr.invars`. `Var` entries stay invars of the result. Any other
      entry (a `Literal` or a concrete value) is treated as a constant.
  Returns:
    A `Jaxpr` with no constvars.
  """
    # TODO(j-towns): check that consts are in jax.core.literalable_types
    consts = dict(zip(jaxpr.constvars, consts))
    if invals is not None:
        # We're in a call_jaxpr
        new_jaxpr_invars = []
        for var, val in zip(jaxpr.invars, invals):
            if isinstance(val, Var):
                new_jaxpr_invars.append(var)
            elif isinstance(val, Literal):
                # Store the literal's value, not the wrapper: folded equations
                # pass `consts` entries straight to `bind`.
                consts[var] = val.val
            else:
                consts[var] = val
    else:
        new_jaxpr_invars = jaxpr.invars
    new_eqns = []
    for eqn in jaxpr.eqns:
        if all(
                isinstance(var, Literal) or var in consts
                for var in eqn.invars):
            # Perform constant folding if all inputs to an eqn are known
            in_vals = [
                var.val if isinstance(var, Literal) else consts[var]
                for var in eqn.invars
            ]
            call_jaxpr, params = jc.extract_call_jaxpr(eqn.primitive,
                                                       eqn.params)
            if call_jaxpr:
                subfuns = [
                    lu.wrap_init(partial(jc.eval_jaxpr, call_jaxpr, ()))
                ]
            else:
                subfuns = []
            ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
            if eqn.primitive.multiple_results:
                for outvar, out in zip(eqn.outvars, ans):
                    consts[outvar] = out
            else:
                outvar, = eqn.outvars
                consts[outvar] = ans
        else:
            new_invars = [
                consts[var] if
                (isinstance(var, Var) and var in consts) else var
                for var in eqn.invars
            ]
            new_params = dict(eqn.params)
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                # Known operands are submerged into the sub-jaxpr, so only
                # the remaining variables stay operands of the call.
                new_params['call_jaxpr'] = submerge_consts(
                    eqn.params['call_jaxpr'], [], new_invars)
                new_invars = [
                    var for var in new_invars if isinstance(var, Var)
                ]
            else:
                new_invars = [
                    var if isinstance(var, (Var, Literal)) else Literal_(var)
                    for var in new_invars
                ]
            new_eqns.append(
                JaxprEqn(invars=new_invars,
                         outvars=eqn.outvars,
                         primitive=eqn.primitive,
                         params=new_params,
                         source_info=eqn.source_info))
    # Outvars that were folded away, or that are themselves substituted
    # constvars/invars, are no longer defined by any equation: emit literals.
    new_outvars = [
        Literal_(consts[var]) if isinstance(var, Var) and var in consts else var
        for var in jaxpr.outvars
    ]
    return Jaxpr([], new_jaxpr_invars, new_outvars, new_eqns)