コード例 #1
0
        def vjp_func(tangents):
            """Pull `tangents` back with `aux_vjp` and gather per-layer-tag data.

            The first element returned by `aux_vjp` is a dict of variable
            tangents; the remaining elements are flattened and attached to
            `jaxpr.invars`. Each equation in `layer_tags` is then split by its
            layer tag into outputs / inputs / params, once for the primal values
            (read from `primals_dict`) and once for the tangents.

            Returns:
              A tuple with one dict per entry of `layer_tags`, with keys
              "outputs", "inputs", "params", "outputs_tangent",
              "inputs_tangent" and "params_tangent".
            """
            vjp_out = aux_vjp(tangents)
            tangents_dict = vjp_out[0]
            flat_input_tangents = jax.tree_flatten(vjp_out[1:])[0]
            tangents_dict.update(zip(jaxpr.invars, flat_input_tangents))

            read_primals = functools.partial(tgm.read_env, primals_dict)
            read_tangents = functools.partial(tgm.read_env, tangents_dict)

            def collect(eqn):
                tag = _unbox_layer_tag(eqn)
                invars = tuple(eqn.invars)
                outputs, inputs, params = tag.split_all_inputs(
                    jax_util.safe_map(read_primals, invars))
                outputs_t, inputs_t, params_t = tag.split_all_inputs(
                    jax_util.safe_map(read_tangents, invars))
                return {
                    "outputs": outputs,
                    "inputs": inputs,
                    "params": params,
                    "outputs_tangent": outputs_t,
                    "inputs_tangent": inputs_t,
                    "params_tangent": params_t,
                }

            return tuple(collect(eqn) for eqn in layer_tags)
コード例 #2
0
    def forward_compute_losses(params_primals: Any) -> Sequence[Sequence[jnp.ndarray]]:
        """Re-evaluate `jaxpr` with new parameters and return loss-tag inputs.

        `params_primals` is stored at `primals_[params_index]` (mutating the
        enclosing `primals_`). The flattened `primals_` are then bound to
        `jaxpr.invars` and the equations are evaluated in order. Evaluation
        stops early once `num_losses` loss tags have been seen (when
        `num_losses` is not None).

        Returns:
          For every loss-tag equation encountered, a tuple of the values of its
          input variables.
        """
        primals_[params_index] = params_primals
        env = {}  # variable -> value
        read = functools.partial(tgm.read_env, env)
        write = functools.partial(tgm.write_env, env)

        # Seed the environment: the unit var, the flat arguments, the consts.
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(primals_)[0])
        jax_util.safe_map(write, jaxpr.constvars, consts)

        loss_eqns = []
        for eqn in jaxpr.eqns:
            tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
            if isinstance(eqn.primitive, tags.LossTag):
                loss_eqns.append(eqn)
            if num_losses is not None and len(loss_eqns) == num_losses:
                break
        return tuple(tuple(map(read, eqn.invars)) for eqn in loss_eqns)
コード例 #3
0
    def trace_to_jaxpr_finalize(in_tracers,
                                out_tracers,
                                trace,
                                instantiate=True):
        """Builds a `TypedJaxpr` and its consts from traced inputs and outputs.

        Args:
          in_tracers: tracers for the traced computation's inputs; the first
            component of each tracer's `pval` supplies that input's aval.
          out_tracers: tracers for the traced computation's outputs.
          trace: the trace the outputs are raised into before conversion.
          instantiate: a value repeated once per output and passed to
            `pe.instantiate_const_at`.

        Returns:
          `(typed_jaxpr, consts)`. The jaxpr is the result of
          `pe.convert_constvars_jaxpr`, and its input avals are the const avals
          followed by the input avals.
        """
        # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
        instantiate = [instantiate] * len(out_tracers)
        # Lower every output first, then raise each one into `trace` so all
        # outputs belong to that trace before they are instantiated.
        out_tracers = safe_map(trace.full_raise,
                               safe_map(core.full_lower, out_tracers))
        out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                               instantiate, out_tracers)
        jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
        out_pvals = [t.pval for t in out_tracers]
        # TODO: this is from partial_eval.trace_to_jaxpr. Share.
        assert not env

        # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
        # Output avals: the first component of each output pval, passed
        # through raise_to_shaped.
        out_avals = safe_map(abstract_arrays.raise_to_shaped,
                             unzip2(out_pvals)[0])
        const_avals = tuple(
            abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)

        in_pvals = [t.pval for t in in_tracers]
        in_avals = tuple(
            safe_map(abstract_arrays.raise_to_shaped,
                     unzip2(in_pvals)[0]))

        # Const avals are listed before input avals; presumably this matches
        # convert_constvars_jaxpr moving the constvars to the front of the
        # invars — TODO confirm.
        typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                      const_avals + in_avals, out_avals)
        return typed_jaxpr, consts
コード例 #4
0
ファイル: loops.py プロジェクト: self-supervisor/jax
 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
   """Convert traced inputs/outputs into a `ClosedJaxpr` plus its consts.

   Args:
     in_tracers: tracers standing for the traced function's inputs.
     out_tracers: tracers for the traced function's outputs.
     trace: trace into which every output is raised before conversion.
     instantiate: value applied to every output via `pe.instantiate_const_at`.

   Returns:
     `(closed_jaxpr, consts)`, where the jaxpr has had its constvars converted
     by `pe.convert_constvars_jaxpr` and `consts` are the captured constants.
   """
   # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
   flags = [instantiate] * len(out_tracers)
   lowered = safe_map(core.full_lower, out_tracers)
   raised = safe_map(trace.full_raise, lowered)
   finalized = safe_map(partial(pe.instantiate_const_at, trace), flags, raised)
   jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, finalized)
   # TODO: this is from partial_eval.trace_to_jaxpr. Share.
   assert not env
   return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()), consts
コード例 #5
0
def plant_function(main: jax_core.MainTrace, settings: HarvestSettings,
                   in_tree: Any, args: Iterable[Any]):
    """A function transformation that injects values in place of sows.

    This is a generator. It first yields `(args, {})` (the wrapped positional
    arguments and empty kwargs) and receives the transformed function's outputs
    back through that `yield`. It then yields the unwrapped output values.

    Args:
      main: the `MainTrace` used to build the `HarvestTrace` and to open the
        dynamic context.
      settings: `HarvestSettings` stored on the `PlantContext`.
      in_tree: tree definition used to unflatten `args` into `(plants, args)`.
      args: flat inputs; after unflattening, the first part holds the values to
        plant and the second part the function's real arguments.

    Yields:
      First `(args, {})`, with each real argument wrapped by `trace.pure`; then
      the list of `.val`s of the output tracers.
    """
    trace = HarvestTrace(main, jax_core.cur_sublevel())
    plants, args = tree_util.tree_unflatten(in_tree, args)
    # Wrap each real argument as a tracer of this trace.
    args = jax_util.safe_map(trace.pure, args)
    context = PlantContext(settings, plants)
    # While the function runs, `context` is the dynamic context associated with
    # `main` (presumably read back by the trace's primitive rules via
    # `trace_util.get_dynamic_context` — confirm).
    with trace_util.new_dynamic_context(main, context):
        ans = yield args, {}
        out_tracers = jax_util.safe_map(trace.full_raise, ans)
        # NOTE(review): dropping the local `main` reference before the context
        # exits is presumably for GC / leak checking — confirm.
        del main
    yield [t.val for t in out_tracers]
コード例 #6
0
def jaxpr_to_expressions(jaxpr: jax_core.Jaxpr) -> Tuple[Expr, ...]:
    """Converts a JAXpr into a tuple of output `JaxExpression`s.

    Every constvar and invar becomes a `JaxVar` pattern keyed by its string
    name. Each equation is turned into an expression over its operands; call
    primitives recurse into their sub-jaxpr. The expressions bound to
    `jaxpr.outvars` are returned.

    Args:
      jaxpr: a `jax.core.Jaxpr` to be converted into a tuple of
        `JaxExpression`s.

    Returns:
      A tuple of `JaxExpression`s, one per entry of `jaxpr.outvars`.
    """
    env = {}

    def read_env(var: jax_core.Var) -> Any:
        # Literals carry their own value; all other variables were bound earlier.
        if isinstance(var, jax_core.Literal):
            return Literal(var.val)
        return env[str(var)]

    def write_env(var: jax_core.Var, val: Any) -> None:
        if isinstance(var, jax_core.Literal):
            return
        env[str(var)] = val

    def var_pattern(var: jax_core.Var) -> JaxVar:
        # A named pattern carrying the variable's shape and dtype.
        return JaxVar(str(var), var.aval.shape, var.aval.dtype)

    # Constants are bound first, then inputs, both as named variable patterns.
    jax_util.safe_map(write_env, jaxpr.constvars,
                      jax_util.safe_map(var_pattern, jaxpr.constvars))
    jax_util.safe_map(write_env, jaxpr.invars,
                      jax_util.safe_map(var_pattern, jaxpr.invars))

    for eqn in jaxpr.eqns:
        operands = tuple(jax_util.safe_map(read_env, eqn.invars))

        call_jaxpr, params = jax_core.extract_call_jaxpr(
            eqn.primitive, eqn.params)
        if call_jaxpr:
            # Call primitives keep their body as a nested expression, together
            # with the names of the sub-jaxpr's inputs.
            call_expression = BoundExpression(jaxpr_to_expressions(call_jaxpr),
                                              {})
            variable_names = tuple(map(str, call_jaxpr.invars))
            out = CallPrimitive(eqn.primitive, operands, call_expression,
                                Params(params), variable_names)
        else:
            out = primitive_to_expression(eqn.primitive)(operands,
                                                         Params(params))
        if eqn.primitive.multiple_results:
            # Each output of a multi-result primitive is addressed via `Part`.
            out_parts = [Part(out, i) for i in range(len(eqn.outvars))]
            jax_util.safe_map(write_env, eqn.outvars, out_parts)
        else:
            write_env(eqn.outvars[0], out)
    return tuple(jax_util.safe_map(read_env, jaxpr.outvars))
コード例 #7
0
 def default_process_primitive(
     self, primitive: jax_core.Primitive, tracers: List['HarvestTracer'],
     params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
   """Evaluate `primitive` on the tracers' values and re-wrap the results.

   `sow_p` is handed to the dynamic context's `process_sow`; every other
   primitive is bound directly on the unwrapped values.

   Returns:
     A list of tracers for `sow_p` and for multiple-result primitives,
     otherwise a single tracer.
   """
   context = trace_util.get_dynamic_context(self)
   vals = [t.val for t in tracers]
   if primitive is sow_p:
     return jax_util.safe_map(self.pure, context.process_sow(*vals, **params))
   result = primitive.bind(*vals, **params)
   if primitive.multiple_results:
     return jax_util.safe_map(self.pure, result)
   return self.pure(result)
コード例 #8
0
ファイル: transform.py プロジェクト: yashk2810/jax
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    argspecs: Sequence[ArgSpec],  # mix of sparse and dense pointers into spenv
    spenv: SparseEnv,
) -> Sequence[ArgSpec]:
    """Evaluates `jaxpr` on a mix of sparse and dense `ArgSpec`s.

    Equations with at least one sparse input are dispatched to the matching
    rule in `sparse_rules`. For all-dense equations, `prim.bind` is called on the
    data read from `spenv`, and each result buffer is pushed into `spenv` and
    wrapped in a new `ArgSpec`.

    Args:
      jaxpr: the jaxpr to interpret.
      consts: dense values for `jaxpr.constvars`; each is pushed into `spenv`.
      argspecs: `ArgSpec`s for `jaxpr.invars`.
      spenv: the `SparseEnv` that buffers are pushed into and read from.

    Returns:
      The `ArgSpec`s bound to `jaxpr.outvars`.

    Raises:
      NotImplementedError: if an equation has a sparse input and its primitive
        has no entry in `sparse_rules`.
    """
    env: Dict[core.Var, ArgSpec] = {}

    def read(var: core.Var) -> ArgSpec:
        # all literals are dense
        # NOTE: every read of a literal pushes a fresh buffer into `spenv`.
        if isinstance(var, core.Literal):
            return ArgSpec(np.shape(var.val), spenv.push(var.val), None)
        else:
            return env[var]

    def write_buffer(var: core.Var, a: Array) -> None:
        # Push a dense array into `spenv` and bind `var` to an ArgSpec for it.
        if var is core.dropvar:
            return
        env[var] = ArgSpec(a.shape, spenv.push(a), None)

    def write(var: core.Var, a: ArgSpec) -> None:
        # Outputs bound to the drop var are discarded.
        if var is core.dropvar:
            return
        env[var] = a

    # TODO: handle unitvar at all?
    #write_buffer(core.unitvar, core.unit)
    safe_map(write_buffer, jaxpr.constvars, consts)
    safe_map(write, jaxpr.invars, argspecs)

    for eqn in jaxpr.eqns:
        prim = eqn.primitive
        invals = safe_map(read, eqn.invars)

        if any(val.is_sparse() for val in invals):
            if prim not in sparse_rules:
                raise NotImplementedError(f"sparse rule for {prim}")
            out = sparse_rules[prim](spenv, *invals, **eqn.params)
        else:
            if prim is xla.xla_call_p:
                # TODO(vanderplas,frostig): workaround for binding call primitives
                # within a jaxpr interpreter
                params = eqn.params.copy()
                fun = lu.wrap_init(
                    core.jaxpr_as_fun(
                        pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
                out_bufs = prim.bind(fun, *(val.data(spenv) for val in invals),
                                     **params)
            else:
                out_bufs = prim.bind(*(val.data(spenv) for val in invals),
                                     **eqn.params)
            # Normalize to a list so single- and multi-result primitives share
            # the code below.
            out_bufs = out_bufs if prim.multiple_results else [out_bufs]
            out = []
            # NOTE(review): buffers for outputs bound to the drop var are still
            # pushed into `spenv`, even though `write` then discards them.
            for buf in out_bufs:
                out.append(ArgSpec(buf.shape, spenv.push(buf), None))
        safe_map(write, eqn.outvars, out)

    return safe_map(read, jaxpr.outvars)
コード例 #9
0
def _plant_cond_rule(trace, *tracers, branches, linear):
    """Injects the same values into every branch of a conditional.

    Each branch jaxpr is wrapped with `plant(...)`, using the dynamic context's
    settings and plants. The wrapped branches are re-traced into jaxprs that
    share a common list of consts, and `lax.cond_p` is bound again on the raw
    values.

    Args:
      trace: the trace whose dynamic context provides the settings and plants.
      *tracers: the branch-index tracer followed by the operand tracers.
      branches: the branch jaxprs of the original conditional.
      linear: the original per-operand linearity flags. Only its length is
        used; every flag of the new bind is `False`.

    Returns:
      The outputs of the new `cond_p` bind, each wrapped with `trace.pure`.
    """
    index_tracer, ops_tracers = tracers[0], tracers[1:]
    # Unwrap the index and operand tracers to their underlying values.
    index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
                                             (index_tracer, ops_tracers))
    ops_avals = tree_util.tree_map(lambda x: x.aval, ops_tracers)
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Forward the current harvest settings to `plant` for every branch.
    plant_settings = dict(tag=settings.tag,
                          allowlist=settings.allowlist,
                          blocklist=settings.blocklist,
                          exclusive=settings.exclusive)
    # NOTE(review): `_check_branch_metadata` presumably rejects branches whose
    # sows disagree — confirm.
    branch_metadatas = tuple(
        _get_harvest_metadata(branch, settings, *ops_tracers)
        for branch in branches)
    _check_branch_metadata(branch_metadatas)
    plants = context.plants
    branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
    # The same `plants` are bound as the first argument of every planted branch.
    planted_branches = tuple(
        functools.partial(plant(f, **plant_settings), plants)
        for f in branch_funs)
    in_tree = tree_util.tree_structure(ops_avals)
    # Re-trace the planted branches so that they share one list of consts.
    new_branch_jaxprs, consts, _ = (
        lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
            planted_branches, in_tree, ops_avals, lax.cond_p.name))
    # Shared consts go ahead of the operands; nothing is marked linear.
    out = lax.cond_p.bind(index_val,
                          *(tuple(consts) + ops_vals),
                          branches=tuple(new_branch_jaxprs),
                          linear=(False, ) * len(tuple(consts) + linear))
    return jax_util.safe_map(trace.pure, out)
コード例 #10
0
def _reap_cond_rule(trace, *tracers, branches, linear):
    """Reaps each path of the `cond`.

    Every branch is wrapped with `call_and_reap`, and the wrapped branches are
    re-traced into jaxprs that share one list of consts. `lax.cond_p` is then
    bound again. The reaped values produced by the new conditional are sown
    into the enclosing context, using the mode recorded in the first branch's
    metadata.

    Args:
      trace: the trace whose dynamic context provides the harvest settings.
      *tracers: the branch-index tracer followed by the operand tracers.
      branches: the branch jaxprs of the original conditional.
      linear: the original per-operand linearity flags. Only its length is
        used; every flag of the new bind is `False`.

    Returns:
      The conditional's outputs without the reaped values, wrapped with
      `trace.pure`.
    """
    index_tracer, ops_tracers = tracers[0], tracers[1:]
    # Unwrap the index and operand tracers to their underlying values.
    index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
                                             (index_tracer, ops_tracers))
    _, ops_avals = tree_util.tree_map(lambda x: x.aval,
                                      (index_tracer, ops_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Forward the current harvest settings to `call_and_reap` for every branch.
    reap_settings = dict(tag=settings.tag,
                         allowlist=settings.allowlist,
                         blocklist=settings.blocklist,
                         exclusive=settings.exclusive)
    # NOTE(review): `_check_branch_metadata` presumably ensures all branches
    # sow the same names, which is why branch 0's metadata is used below —
    # confirm.
    branch_metadatas = tuple(
        _get_harvest_metadata(branch, settings, *ops_tracers)
        for branch in branches)
    _check_branch_metadata(branch_metadatas)
    branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches))
    reaped_branches = tuple(
        call_and_reap(f, **reap_settings) for f in branch_funs)
    in_tree = tree_util.tree_structure(ops_avals)
    # Re-trace the reaping branches so that they share one list of consts.
    new_branch_jaxprs, consts, out_trees = (
        lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
            reaped_branches, in_tree, ops_avals, lax.cond_p.name))
    # Shared consts go ahead of the operands; nothing is marked linear.
    out = lax.cond_p.bind(index_val,
                          *(tuple(consts) + ops_vals),
                          branches=tuple(new_branch_jaxprs),
                          linear=(False, ) * len(tuple(consts) + linear))
    out = jax_util.safe_map(trace.pure, out)
    # Each reaping branch returns `(outputs, reaps)`; split them apart.
    out, reaps = tree_util.tree_unflatten(out_trees[0], out)
    # Re-sow the reaped values into the enclosing context.
    for k, v in reaps.items():
        sow(v, name=k, tag=settings.tag, mode=branch_metadatas[0][k]['mode'])
    return out
コード例 #11
0
ファイル: jax2tf.py プロジェクト: ekelsen/jax
    def process_primitive(self, primitive: core.Primitive,
                          tracers: Sequence[TensorFlowTracer],
                          params) -> TensorFlowTracer:
        """Evaluate `primitive` with its TF implementation and wrap the result.

        Unless `core.skip_checks` is set, the avals of the produced tracers are
        compared against `primitive.abstract_eval`. For multiple-result
        primitives the number of outputs must also match.

        Args:
          primitive: the JAX primitive being processed.
          tracers: `TensorFlowTracer` inputs; their `.val`s are passed to the
            implementation.
          params: primitive parameters, forwarded as keyword arguments.

        Returns:
          A `TensorFlowTracer`, or a list of them when
          `primitive.multiple_results` is set.
        """
        impl = self.get_primitive_impl(primitive)
        args_tf: Sequence[TfValOrUnit] = [t.val for t in tracers]
        # impl takes core.unit and returns core.unit when needed.
        val_out: TfValOrUnit = impl(*args_tf, **params)
        if primitive.multiple_results:
            out = util.safe_map(functools.partial(TensorFlowTracer, self),
                                val_out)  # type: ignore
        else:
            out = TensorFlowTracer(self, val_out)

        # Check that the impl rule returned a value of expected shape and dtype
        if not core.skip_checks:
            expected_out_aval: core.AbstractValue = primitive.abstract_eval(
                *[t.aval for t in tracers], **params)
            if primitive.multiple_results:
                # `safe_zip`, unlike `zip`, rejects an impl that returns the
                # wrong number of outputs instead of only checking the common
                # prefix.
                for o, expected_aval in util.safe_zip(
                        out, expected_out_aval):  # type: ignore
                    assert o.aval == expected_aval, (
                        f"{primitive}: out.aval = {o.aval}; expected {expected_aval}"
                    )
            else:
                assert out.aval == expected_out_aval, (  # type: ignore
                    f"{primitive}: out.aval = {out.aval}; expected {expected_out_aval}"
                )  # type: ignore
        return out  # type: ignore
コード例 #12
0
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
  """Interpret `jaxpr` over a mix of sparse and dense values.

  Equations with any sparse input go through `sparse_rules`. Equations whose
  inputs are all dense are bound on the raw buffers from `spenv`, and their
  results are wrapped back with `spenv.dense`. Outputs bound to a `DropVar`
  become `None` and are never stored.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: dense values for `jaxpr.constvars`.
    spvalues: `SparsifyValue`s for `jaxpr.invars`.
    spenv: the `SparsifyEnv` that owns the underlying buffers.

  Returns:
    The `SparsifyValue`s bound to `jaxpr.outvars`.
  """
  env: Dict[core.Var, SparsifyValue] = {}

  def read(var: core.Var) -> Union[Array, SparsifyValue]:
    # Literals are always treated as dense.
    return spenv.dense(var.val) if isinstance(var, core.Literal) else env[var]

  def write_const(var: core.Var, a: Array) -> None:
    if not isinstance(var, core.DropVar):
      env[var] = spenv.dense(a)

  def write(var: core.Var, a: SparsifyValue) -> None:
    if not isinstance(var, core.DropVar):
      assert a is not None
      env[var] = a

  def bind_dense(prim, invals, params):
    # Bind `prim` on the raw buffers behind `invals`; always returns a list.
    if prim is xla.xla_call_p:
      # TODO(vanderplas,frostig): workaround for binding call primitives
      # within a jaxpr interpreter
      params = params.copy()
      closed_call = pe.ClosedJaxpr(params.pop('call_jaxpr'), ())
      fun = lu.wrap_init(core.jaxpr_as_fun(closed_call))
      result = prim.bind(fun, *(spenv.data(v) for v in invals), **params)
    else:
      result = prim.bind(*(spenv.data(v) for v in invals), **params)
    return result if prim.multiple_results else [result]

  safe_map(write_const, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)
    if not any(val.is_sparse() for val in invals):
      buffers = bind_dense(prim, invals, eqn.params)
      out = [None if isinstance(outvar, core.DropVar) else spenv.dense(buf)
             for buf, outvar in safe_zip(buffers, eqn.outvars)]
    else:
      if prim not in sparse_rules:
        _raise_unimplemented_primitive(prim)
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
コード例 #13
0
def evaluate_eqn(eqn, in_values, write_func):
    """Evaluate one JAX equation, write its outputs, and return them.

    Call primitives (those for which `extract_call_jaxpr` yields a sub-jaxpr,
    e.g. `xla_call`) receive the sub-jaxpr, wrapped as a function, in front of
    the regular operands.

    Args:
      eqn: the equation to evaluate.
      in_values: values for `eqn.invars`, in order.
      write_func: callable `(var, value)` used to record each output.

    Returns:
      The raw result of `eqn.primitive.bind`.
    """
    bind_args = list(in_values)
    call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
        body = functools.partial(jax.core.eval_jaxpr, call_jaxpr, ())
        bind_args.insert(0, jax.core.lu.wrap_init(body))
    ans = eqn.primitive.bind(*bind_args, **params)
    if not eqn.primitive.multiple_results:
        write_func(eqn.outvars[0], ans)
    else:
        jax_util.safe_map(write_func, eqn.outvars, ans)
    return ans
コード例 #14
0
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
    """Reaps the body of a while loop to get the reaps of the final iteration.

    The body is re-traced with `call_and_reap` and the loop carry is extended
    with the body's reaped values, which start as zeros. After the loop runs,
    the reaped values left in the carry are sown again into the enclosing
    context.

    Args:
      trace: the `HarvestTrace` whose dynamic context supplies the settings.
      *tracers: cond consts, then body consts, then the initial carry values.
      cond_jaxpr: the loop-condition jaxpr.
      body_jaxpr: the loop-body jaxpr.
      cond_nconsts: number of leading tracers that are cond consts.
      body_nconsts: number of following tracers that are body consts.

    Returns:
      The loop's carry outputs, wrapped with `trace.pure`.

    Raises:
      ValueError: if a value sown in the body uses any mode other than
        'clobber'.
    """
    cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
        tracers, [cond_nconsts, body_nconsts])
    # Only the initial-carry avals are kept; the body-const avals are discarded.
    _, init_avals = tree_util.tree_map(lambda x: x.aval,
                                       (body_const_tracers, init_tracers))
    cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
        lambda x: x.val,
        (cond_const_tracers, body_const_tracers, init_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                          *(body_const_tracers + init_tracers))
    # Only 'clobber' is allowed, so each iteration overwrites the carried reap.
    for k, meta in body_metadata.items():
        mode = meta['mode']
        if mode != 'clobber':
            raise ValueError(
                f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
    reap_avals = {k: v['aval'] for k, v in body_metadata.items()}

    cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
    body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
    reap_settings = dict(tag=settings.tag,
                         allowlist=settings.allowlist,
                         blocklist=settings.blocklist,
                         exclusive=settings.exclusive)

    # The new carry is `(carry, reaps)`; the condition ignores the reaps.
    def new_cond(carry, _):
        return cond_fun(*(cond_const_vals + carry))

    # Run the original body and replace the carried reaps with this
    # iteration's reaped values.
    def new_body(carry, _):
        carry, reaps = call_and_reap(
            body_fun, **reap_settings)(*(body_const_vals + carry))
        return (carry, reaps)

    new_in_avals, new_in_tree = tree_util.tree_flatten(
        (init_avals, reap_avals))
    new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_cond, new_in_tree, tuple(new_in_avals))
    new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, new_in_tree, tuple(new_in_avals))
    # Initial reap slots: zeros with each reap's recorded shape and dtype.
    dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype),
                                         reap_avals)
    new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals))
    out = lax.while_p.bind(*(cond_consts + body_consts + new_in_vals),
                           cond_nconsts=len(cond_consts),
                           body_nconsts=len(body_consts),
                           cond_jaxpr=new_cond_jaxpr,
                           body_jaxpr=new_body_jaxpr)
    out = jax_util.safe_map(trace.pure, out)
    out, reaps = tree_util.tree_unflatten(out_tree, out)
    # Re-sow the reaps from the final iteration into the enclosing context.
    for k, v in reaps.items():
        sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode'])
    return out
コード例 #15
0
ファイル: loops.py プロジェクト: x1489/jax
    def build_output_vals(self, scope, carried_state_names, carried_tree,
                          init_vals, body_closed_jaxpr, body_const_vals):
        """Bind a `while_p` whose condition is traced from `self.cond_func`.

        Args:
          scope: the loop scope. Its `_mutable_state` is populated with the
            carried values while the condition is traced.
          carried_state_names: names of the scope fields carried by the loop.
          carried_tree: tree definition of the carried state.
          init_vals: initial values of the carried state.
          body_closed_jaxpr: the already-traced loop body.
          body_const_vals: constants for `body_closed_jaxpr`.

        Returns:
          The outputs of `lax_control_flow.while_p.bind`.

        Raises:
          ValueError: if the conditional function modifies a carried scope
            field.
          TypeError: if the conditional function does not return a boolean
            scalar.
        """
        # Trace the conditional function. cond_func takes 0 arguments, but
        # for lax.while we need a conditional function that takes the
        # carried_state_names. _initial_style_jaxpr will start its own trace and
        # will create tracers for all the carried state. We must put these values
        # in the scope._mutable_state before we trace the conditional
        # function.
        def cond_func_wrapped(*args):
            assert len(args) == len(carried_state_names)
            for ms, init_ms in zip(carried_state_names, args):
                scope._mutable_state[ms] = init_ms
            res = self.cond_func()
            # Conditional function is not allowed to modify the scope state
            for ms, init_ms in zip(carried_state_names, args):
                if scope._mutable_state[ms] is not init_ms:
                    raise ValueError(
                        f"Conditional function modifies scope.{ms} field.")
            return res

        init_avals = safe_map(_BodyTracer.abstractify, init_vals)
        cond_jaxpr, cond_consts, cond_tree = (
            lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
                                                  carried_tree,
                                                  tuple(init_avals)))
        # TODO: share these checks with lax_control_flow.while
        if not tree_util.treedef_is_leaf(cond_tree):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got pytree {cond_tree}."
            )
        # `safe_map` returns a list, and a non-empty list is always truthy, so
        # the per-output results must be reduced with `all` for this check to
        # ever fire.
        if not all(safe_map(core.typecompat, cond_jaxpr.out_avals,
                            [core.ShapedArray((), np.bool_)])):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got output type(s) "
                f"{cond_jaxpr.out_avals}.")

        return lax_control_flow.while_p.bind(*itertools.chain(
            cond_consts, body_const_vals, init_vals),
                                             cond_nconsts=len(cond_consts),
                                             cond_jaxpr=cond_jaxpr,
                                             body_nconsts=len(body_const_vals),
                                             body_jaxpr=body_closed_jaxpr)
コード例 #16
0
ファイル: jax_to_tf.py プロジェクト: stilling/jax
  def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]:
    """Run one batched loop step, keeping the old carry where `pred_b` is false.

    Args:
      pred_b: the current (possibly batched) loop predicate.
      *carry: the current carry values.

    Returns:
      `(next_pred_b, *selected_carry)`: the predicate recomputed on the
      selected carry, followed by the selected carry values.
    """
    stepped = _interpret_jaxpr(body_jaxpr, *body_consts, *carry)

    def keep_where_active(updated, previous):
      # Broadcast the predicate over the leading dimensions of the carry value.
      mask = _broadcast_in_dim(pred_b, updated.shape,
                               list(range(len(pred_b.shape))))
      return tf.where(mask, updated, previous)

    selected = list(util.safe_map(keep_where_active, stepped, carry))
    (next_pred_b,) = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected)
    return (next_pred_b, *selected)
コード例 #17
0
def reap_function(main: jax_core.MainTrace, settings: HarvestSettings,
                  return_metadata: bool, args: Iterable[Any]):
    """A function transformation that returns reap values.

    This is a generator. It first yields the input tracers and an empty kwargs
    dict, and receives the transformed function's outputs back through that
    `yield`. It then yields the unwrapped outputs together with the reaped
    values.

    Args:
      main: the `MainTrace` used to build the `HarvestTrace` and to open the
        dynamic context.
      settings: `HarvestSettings` stored on the `ReapContext`.
      return_metadata: whether to also yield the metadata of each reap.
      args: the function's flat inputs.

    Yields:
      First `(in_tracers, {})`. Then `(out_values, reap_values)`, or
      `(out_values, reap_values, reap_metadata)` when `return_metadata` is
      true.
    """
    trace = HarvestTrace(main, jax_core.cur_sublevel())
    # Wrap each input as a tracer of this trace.
    in_tracers = jax_util.safe_map(trace.pure, args)
    # NOTE(review): the empty dict presumably becomes `context.reaps` and is
    # filled in while the function runs — confirm.
    context = ReapContext(settings, {})
    with trace_util.new_dynamic_context(main, context):
        ans = yield in_tracers, {}
        out_tracers = jax_util.safe_map(trace.full_raise, ans)
        # Each reap entry exposes `.value` and `.metadata`; the values are raised
        # into this trace so they can be unwrapped like the outputs.
        reap_tracers = tree_util.tree_map(lambda x: trace.full_raise(x.value),
                                          context.reaps)
        reap_metadata = tree_util.tree_map(lambda x: x.metadata, context.reaps)
        # NOTE(review): dropping the local `main` reference before the context
        # exits is presumably for GC / leak checking — confirm.
        del main
    out_values, reap_values = tree_util.tree_map(lambda x: x.val,
                                                 (out_tracers, reap_tracers))
    if return_metadata:
        out = (out_values, reap_values, reap_metadata)
    else:
        out = (out_values, reap_values)
    yield out
コード例 #18
0
ファイル: loops.py プロジェクト: Guillem96/jax
 def build_output_vals(self, scope, carried_state_names, carried_tree,
                       init_vals, body_typed_jaxpr, body_const_vals):
   """Bind a `cond_p` that runs the traced body only when `self.pred` holds.

   The false branch is a traced pass-through of the carried state, so when the
   predicate is false the carried values come out unchanged.

   Returns:
     The outputs of `lax_control_flow.cond_p.bind`.
   """
   def passthrough(*carried):
     return carried

   carried_avals = tuple(safe_map(_BodyTracer.abstractify, init_vals))
   false_jaxpr, false_consts, _ = lax_control_flow._initial_style_jaxpr(
       passthrough, carried_tree, carried_avals)
   # Operand layout: predicate, true-branch consts and carry, then
   # false-branch consts and carry.
   operands = itertools.chain([self.pred], body_const_vals, init_vals,
                              false_consts, init_vals)
   return lax_control_flow.cond_p.bind(
       *operands, true_jaxpr=body_typed_jaxpr, false_jaxpr=false_jaxpr)
コード例 #19
0
 def jaxpr_const_maker(*args, **kwargs):
     """Trace `fun` on `(args, kwargs)`; return the jaxpr and its constants."""
     flat_inputs, in_tree = tree_util.tree_flatten((args, kwargs))
     # Adapt `fun` so it takes flat inputs and produces a flat list of outputs.
     flat_fun, _ = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
     # One partial value per flat input, built by `pv_like`.
     partial_vals = safe_map(pv_like, flat_inputs)
     jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, partial_vals)
     return jaxpr, consts
コード例 #20
0
def _unravel_array_into_pytree(pytree, axis, arr, is_leaf=None):
    """Split `arr` along `axis` into pieces shaped like the leaves of `pytree`.

    The `axis` dimension of `arr` is cut into chunks of `np.size(leaf)`
    elements, one per leaf. Each chunk is reshaped so that this dimension is
    replaced by the leaf's shape, and the chunks are reassembled into
    `pytree`'s structure.

    Args:
      pytree: pytree whose leaves give the target shapes.
      axis: axis of `arr` to split; negative values count from the end.
      arr: array to split; must provide `.split` and `.reshape`.
      is_leaf: optional predicate forwarded to `tree_flatten`.

    Returns:
      A pytree with the structure of `pytree` holding the reshaped pieces.
    """
    leaves, treedef = tree_flatten(pytree, is_leaf=is_leaf)
    axis %= arr.ndim
    leading, trailing = arr.shape[:axis], arr.shape[axis + 1:]
    split_points = np.cumsum(safe_map(np.size, leaves[:-1]))
    pieces = arr.split(split_points, axis)
    reshaped = [
        piece.reshape(leading + np.shape(leaf) + trailing)
        for piece, leaf in zip(pieces, leaves)
    ]
    return tree_unflatten(treedef, reshaped)
コード例 #21
0
ファイル: jax2tf.py プロジェクト: hereismari/jax
def _tfval_add_unit(vals: Sequence[TfValOrUnit],
                    avals: Sequence[core.AbstractValue]) -> Sequence[TfValOrUnit]:
  """Map TF values to `TfValOrUnit`, restoring `core.unit` where expected.

  Wherever the matching abstract value is `core.abstract_unit`, the result is
  `core.unit`; every other value is returned unchanged. Callers sometimes pass
  either `core.unit` or `tf.nan` as the placeholder for a unit, so both are
  accepted. The check is skipped when `core.skip_checks` is set.
  """
  def restore_unit(v: TfValOrUnit, aval: core.AbstractValue):
    expects_unit = aval is core.abstract_unit
    if not core.skip_checks:
      if expects_unit:
        assert v is core.unit or tf.math.is_nan(v)
      else:
        assert _is_tfval(v)
    return core.unit if expects_unit else v

  return util.safe_map(restore_unit, vals, avals)
コード例 #22
0
    def wrapped_auto_registered(*args):
        """Evaluate `graph` on `args`, applying matched tagging functions.

        Every equation output found in `matches` is replaced by the result of
        its tagging function. Variables in `orphan_params` are wrapped with
        `tags.register_generic` before any equation runs. When
        `compute_only_loss_tags` is set, evaluation stops after `num_losses`
        loss tags and `loss_output_vars` are returned instead of the graph's
        outputs.

        Returns:
          The selected output values, unflattened with the matching tree.
        """
        flat_args = jax.tree_flatten(args)[0]
        env = {}  # variable -> value
        read = functools.partial(read_env, env)
        write = functools.partial(write_env, env)

        def tag(var):
            entry = matches.get(var)
            if entry is None:
                return
            inv_map, tagging_func = entry
            # Entries with string keys are skipped; the remaining values are
            # read from the environment.
            var_map = {
                k: v
                for k, v in inv_map.items() if not isinstance(k, str)
            }
            env[var] = tagging_func(inv_map, jax.tree_map(read, var_map))

        # Seed the environment: the unit var, the flat args, the graph consts.
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
        jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)

        # Parameters that no pattern matched are registered as generic.
        for param_var in orphan_params:
            write(param_var, tags.register_generic(read(param_var)))

        if compute_only_loss_tags:
            output_vars = loss_output_vars
            out_tree = jax.tree_structure(loss_output_vars)
        else:
            output_vars, out_tree = graph.jaxpr.outvars, graph.out_tree

        losses_evaluated = 0
        for eqn in graph.jaxpr.eqns:
            evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
            for outvar in eqn.outvars:
                tag(outvar)
            if isinstance(eqn.primitive, tags.LossTag):
                losses_evaluated += 1
            if compute_only_loss_tags and losses_evaluated == num_losses:
                break

        return jax.tree_unflatten(out_tree,
                                  jax_util.safe_map(read, output_vars))
コード例 #23
0
        def forward():
            """Evaluate `jaxpr` on `func_args` up to its last expected loss tag.

            Returns:
              The values of `layer_input_vars` after evaluation.

            Raises:
              ValueError: if the number of loss-tag equations reached does not
                equal `len(loss_tags)`.
            """
            env = {}  # variable -> value
            read = functools.partial(tgm.read_env, env)
            write = functools.partial(tgm.write_env, env)

            # Seed the environment: the unit var, the flat args, the consts.
            write(jax.core.unitvar, jax.core.unit)
            jax_util.safe_map(write, jaxpr.invars,
                              jax.tree_flatten(func_args)[0])
            jax_util.safe_map(write, jaxpr.constvars, consts)

            expected_losses = len(loss_tags)
            seen_losses = 0
            for eqn in jaxpr.eqns:
                tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars),
                                 write)
                if not isinstance(eqn.primitive, tags.LossTag):
                    continue
                seen_losses += 1
                if seen_losses == expected_losses:
                    break
            if seen_losses != expected_losses:
                raise ValueError("This should be unreachable.")

            return jax_util.safe_map(read, layer_input_vars)
コード例 #24
0
        def forward_aux(aux):
            """Evaluate `jaxpr` with values offset by `aux`; collect loss data.

            Every non-literal, non-unit variable found in `aux` has
            `aux[var]` added to the value written for it.

            Args:
              aux: mapping from jaxpr variables to additive offsets.

            Returns:
              `(inputs, kwargs)`. `inputs` is a tuple of each loss's `inputs`.
              `kwargs` is a tuple of dicts with that loss's `targets` and its
              `weight` param. Both are in loss-tag order.

            Raises:
              ValueError: if the number of loss-tag equations reached does not
                equal `len(loss_tags)`.
            """
            env = {}  # variable -> value
            read = functools.partial(tgm.read_env, env)

            def write(var, val):
                # Literals and the unit var are never offset.
                if (not isinstance(var, (jax.core.Literal, jax.core.UnitVar))
                        and var in aux):
                    val = val + aux[var]
                env[var] = val

            # Seed the environment: the unit var, the flat args, the consts.
            write(jax.core.unitvar, jax.core.unit)
            jax_util.safe_map(write, jaxpr.invars,
                              jax.tree_flatten(func_args)[0])
            jax_util.safe_map(write, jaxpr.constvars, consts)

            loss_inputs = []
            loss_kwargs = []
            for eqn in jaxpr.eqns:
                input_values = jax_util.safe_map(read, eqn.invars)
                tgm.evaluate_eqn(eqn, input_values, write)
                if not isinstance(eqn.primitive, tags.LossTag):
                    continue
                weight = eqn.params["weight"]
                loss = eqn.primitive.loss(*input_values, weight=weight)
                loss_inputs.append(loss.inputs)
                loss_kwargs.append(dict(targets=loss.targets, weight=weight))
                if len(loss_inputs) == len(loss_tags):
                    break
            if len(loss_inputs) != len(loss_tags):
                raise ValueError("This should be unreachable.")
            # Read the inputs to the loss functions, but also return the target values
            return tuple(loss_inputs), tuple(loss_kwargs)
コード例 #25
0
def eval_jaxpr_with_state(jaxpr: jax_core.Jaxpr, rules: Rules,
                          consts: Sequence[Value], state: Value,
                          *args: Value) -> Tuple[List[Value], Value]:
    """Interprets a JAXpr and manages an input state with primitive rules.

    The implementation follows `jax.core.eval_jaxpr` closely; the main
    differences are:
    1. Rather than always calling `primitive.bind`, `eval_jaxpr_with_state`
       looks up a rule for the primitive in the provided `rules` dictionary
       first.
    2. A `state` value is provided that is threaded through the execution of
       the JAXpr and whose final value is returned as an additional output.
    3. Equations carrying a sub-jaxpr (as reported by
       `jax_core.extract_call_jaxpr`) are dispatched to the rule registered in
       `_effect_handler_call_rules`, falling back to
       `default_call_interpreter_rule`.

    Args:
      jaxpr: The JAXpr to be interpreted.
      rules: A `dict` that maps JAX primitives to functions that take in
        `(state, *args)` (plus the equation params as keywords) and return
        `(output, new_state)`.
      consts: A list of constant values corresponding to the JAXpr's
        constvars.
      state: The initial state for the interpreter.
      *args: A list of values that correspond to the JAXpr's invars.

    Returns:
      A list of outputs from the JAXpr and the final state.
    """
    env = Environment()

    # Constants are bound before the positional inputs.
    jax_util.safe_map(env.write, jaxpr.constvars, consts)
    jax_util.safe_map(env.write, jaxpr.invars, args)

    for eqn in jaxpr.eqns:
        primitive = eqn.primitive
        invals = jax_util.safe_map(env.read, eqn.invars)
        sub_jaxpr, params = jax_core.extract_call_jaxpr(primitive, eqn.params)

        if sub_jaxpr:
            fallback = functools.partial(default_call_interpreter_rule,
                                         primitive)
            call_rule = _effect_handler_call_rules.get(primitive, fallback)
            outvals, state = call_rule(rules, state, invals, sub_jaxpr,
                                       **params)
        elif primitive in rules:
            outvals, state = rules[primitive](state, *invals, **params)
        else:
            # No rule registered: use the primitive itself; state is untouched.
            outvals = primitive.bind(*invals, **params)

        if not primitive.multiple_results:
            env.write(eqn.outvars[0], outvals)
        else:
            jax_util.safe_map(env.write, eqn.outvars, outvals)

    outputs = jax_util.safe_map(env.read, jaxpr.outvars)
    return outputs, state
コード例 #26
0
ファイル: lax_control_flow.py プロジェクト: jonasrauber/jax
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
  """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).

  Each `tree` is expected to have as many leaves as its paired `avals` has
  entries. The tree structures are compared first; if they are equal, the
  avals are compared pairwise with `typematch`. `what` is prepended to the
  details of the mismatch in the TypeError.

  Raises:
    TypeError: If the tree structures differ, or if any pair of avals fails
      `typematch` (the message then shows each avals list unflattened into
      its tree).
  """
  if tree1 != tree2:
    raise TypeError(
        "{} must have same type structure, got {} and {}.".format(
            what, tree1, tree2))
  if all(safe_map(typematch, avals1, avals2)):
    return
  shown1 = tree_unflatten(tree1, avals1)
  shown2 = tree_unflatten(tree2, avals2)
  raise TypeError(
      "{} must have identical types, got {} and {}.".format(
          what, shown1, shown2))
コード例 #27
0
def _get_harvest_metadata(closed_jaxpr, settings, *args):
    """Probes a jaxpr for metadata such as its sown values.

    `closed_jaxpr` is turned into a function and, under a fresh `HarvestTrace`
    main trace, wrapped with `reap_function` and `_reap_metadata_wrapper`.
    It is then traced on the shaped avals of the flattened `args`. The
    harvest settings used reuse the tag, blocklist and allowlist of
    `settings`, with the fourth positional field set to `True`.

    Returns:
      The value produced by the metadata thunk from `_reap_metadata_wrapper`
      after tracing.
    """
    fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
    with jax_core.new_main(HarvestTrace) as main:
        probe_settings = HarvestSettings(settings.tag, settings.blocklist,
                                         settings.allowlist, True)
        reaped = reap_function(fun, main, probe_settings, True)
        reaped, get_metadata = _reap_metadata_wrapper(reaped)
        flat_args, in_tree = tree_util.tree_flatten(args)
        flat_fun, out_tree = api_util.flatten_fun_nokwargs(reaped, in_tree)
        in_avals = [
            abstract_arrays.raise_to_shaped(jax_core.get_aval(arg))
            for arg in flat_args
        ]
        # The traced jaxpr itself is discarded; tracing presumably only
        # matters for populating the metadata thunk read below.
        pe.trace_to_jaxpr_final(flat_fun, in_avals)
        metadata = get_metadata()
        # Result unused; the call is kept (presumably it asserts that the
        # output tree was populated during tracing).
        out_tree()
    return metadata
コード例 #28
0
ファイル: loops.py プロジェクト: self-supervisor/jax
 def build_output_vals(self, scope, carried_state_names, carried_tree,
                       init_vals, body_closed_jaxpr, body_const_vals):
   """Binds a two-branch `cond` whose branch 0 passes the carried state through.

   Branch 0 is a pass-through ("false") branch that returns the carried values
   unchanged; branch 1 is `body_closed_jaxpr`. `self.index` is passed to
   `cond_p.bind` as the branch index. `scope` and `carried_state_names` are
   accepted but not used here.

   Returns:
     Whatever `lax_control_flow.cond_p.bind` returns for the operands
     `body_const_vals` followed by `init_vals`.
   """
   carried = tree_util.tree_unflatten(carried_tree, init_vals)
   in_vals, in_tree = tree_util.tree_flatten((body_const_vals, carried))
   in_avals = tuple(safe_map(_BodyTracer.abstractify, in_vals))
   # The branch function presumably receives the unflattened
   # (consts, carried) pair; it returns `carried` untouched.
   pass_through_closed_jaxpr, pass_through_const_vals, _ = (
       lax_control_flow._initial_style_jaxpr(
           lambda *branch_args: branch_args[1], in_tree, in_avals))
   # A pass-through branch must not capture constants of its own.
   assert len(pass_through_const_vals) == 0
   operands = [*body_const_vals, *init_vals]
   return lax_control_flow.cond_p.bind(
       self.index, *operands,
       branches=(pass_through_closed_jaxpr, body_closed_jaxpr),
       linear=(False,) * len(operands))
コード例 #29
0
ファイル: __init__.py プロジェクト: notEvil/jax2numpy
def _get_jax_objects(function, args, kwargs):
    """Traces `function` called with `(args, kwargs)` into a Jaxpr.

    The inputs are flattened into leaves, each leaf is converted with
    `_get_partial_value`, and the flattened function is traced with
    `ji_partial_eval.trace_to_jaxpr`.

    Returns:
        A `(jaxpr, constants)` tuple taken from the trace result.
    """
    # Wrap for transformation, then make it take flat leaves of
    # (args, kwargs) and return a flat list of outputs.
    wrapped = j_linear_util.wrap_init(function)
    leaves, in_tree = j_tree_util.tree_flatten((args, kwargs))
    flat_fun, _ = j_api_util.flatten_fun(wrapped, in_tree)
    pvals = [_get_partial_value(leaf) for leaf in leaves]
    jaxpr, _, constants = ji_partial_eval.trace_to_jaxpr(flat_fun, pvals)
    return jaxpr, constants
コード例 #30
0
ファイル: jax2tf.py プロジェクト: saikrishna-1996/jax
def _tfval_add_unit(vals: Sequence[TfValOrUnit],
                    avals: Sequence[core.AbstractValue]) -> Sequence[TfValOrUnit]:
  """Turn regular TfVals into TfValOrUnit, based on expected abstract values.

  Each value whose aval is `core.abstract_unit` is replaced by `core.unit`;
  every other value is returned unchanged. When the aval is a unit, the
  incoming value is either `core.unit`, an EagerTensor holding NaN (tf.nan is
  used as the concrete TF value for units, see _tfval_remove_unit), or a
  graph Tensor whose NaN value has been abstracted away.

  Unless `core.skip_checks` is set, these expectations are verified with
  asserts.
  """
  def convert(v: TfValOrUnit, aval: core.AbstractValue):
    is_unit = aval is core.abstract_unit
    if not core.skip_checks:
      if not is_unit:
        assert _is_tfval(v)
      elif v is not core.unit:
        assert isinstance(v, tf.Tensor)
        if v.device:  # Only for EagerTensor
          assert tf.math.is_nan(v)
    return core.unit if is_unit else v

  return util.safe_map(convert, vals, avals)