コード例 #1
0
ファイル: named_call.py プロジェクト: shafiahmed/dm-haiku
    def named_fun(*args, **kwargs):
        """Calls the closed-over ``fun`` through the ``named_call_p`` primitive.

        ``fun``, ``name`` and ``named_call_p`` are resolved from an enclosing
        scope that is not part of this snippet.

        Args:
          *args: Positional arguments forwarded to ``fun``.
          **kwargs: Keyword arguments forwarded to ``fun``.

        Returns:
          The flat outputs of ``named_call_p.bind``, unflattened with the output
          tree structure recorded by ``api.flatten_fun``.
        """
        # Wrap and flatten f for JAX internals.
        f = lu.wrap_init(fun)
        flat_args, in_tree = jax.tree_flatten((args, kwargs))
        flat_f, out_tree = api.flatten_fun(f, in_tree)

        # Hide any args that are not a valid JaxType by partially applying
        # flat_f; only the valid ones are passed to bind explicitly.
        # The pylint pragma must sit on the line with the protected access.
        dyn_argnums = [
            i for (i, x) in enumerate(flat_args)
            if jax.api._valid_jaxtype(x)  # pylint: disable=protected-access
        ]
        part_flat_f, dyn_args = jax.argnums_partial(flat_f, dyn_argnums,
                                                    flat_args)

        # Call f with a custom XLA subcomputation via named_call & unflatten result.
        out_flat = named_call_p.bind(part_flat_f, *dyn_args, name=name)
        # out_tree is called only after bind; presumably it is populated once
        # flat_f has actually run (jax linear_util convention -- confirm).
        return jax.tree_unflatten(out_tree(), out_flat)
コード例 #2
0
  def named_fun(*args, **kwargs):
    """Invokes the closed-over ``fun`` under the ``named_call_p`` primitive.

    ``fun``, ``name`` and ``named_call_p`` are resolved from an enclosing
    scope that is not part of this snippet.

    Only leaves of the flattened ``(args, kwargs)`` pytree that are
    ``core.Tracer`` instances are handed to the primitive explicitly. Every
    other leaf (e.g. a concrete value when an outer transformation used
    static_argnums) is captured by partial application via
    ``jax.argnums_partial``, so it never appears in the Tracer context.

    Returns:
      The flat outputs of the primitive, rebuilt into the output pytree
      structure reported by ``api.flatten_fun``.
    """
    wrapped = lu.wrap_init(fun)
    # One flat leaf list covers both the positional and the keyword args.
    leaves, args_treedef = jax.tree_flatten((args, kwargs))
    # flat_fun consumes a flat leaf sequence and yields a flat result
    # sequence; its output structure is exposed via get_out_treedef().
    flat_fun, get_out_treedef = api.flatten_fun(wrapped, args_treedef)

    # Tracer leaves stay dynamic; all remaining leaves are baked in.
    tracer_positions = [
        position for position, leaf in enumerate(leaves)
        if isinstance(leaf, core.Tracer)
    ]
    partial_fun, tracer_leaves = jax.argnums_partial(
        flat_fun, tracer_positions, leaves)

    # Bind the named_call primitive, tagging the call with `name`.
    flat_outputs = named_call_p.bind(partial_fun, *tracer_leaves, name=name)

    # The output treedef thunk is read only after the call has happened.
    return jax.tree_unflatten(get_out_treedef(), flat_outputs)
コード例 #3
0
ファイル: named_call.py プロジェクト: qiuminxu/dm-haiku
    def named_fun(*args, **kwargs):
        """Calls the closed-over ``fun`` through the ``named_call_p`` primitive.

        ``fun``, ``name``, ``named_call_p`` and ``config`` are resolved from an
        enclosing scope that is not part of this snippet.

        There are two paths:

        * When ``config.omnistaging_enabled`` is true, the flattened call is
          bound as a zero-argument thunk that closes over every flat arg.
          Per the original comment, this avoids abstracting the inputs.
        * Otherwise, only the flat args that pass
          ``jax.api._valid_jaxtype`` are passed explicitly. The rest are
          captured by partial application via ``jax.argnums_partial``.

        Args:
          *args: Positional arguments forwarded to ``fun``.
          **kwargs: Keyword arguments forwarded to ``fun``.

        Returns:
          The flat outputs of ``named_call_p.bind``, unflattened with the output
          tree structure recorded by ``api.flatten_fun``.
        """
        # Wrap and flatten f for JAX internals.
        f = lu.wrap_init(fun)
        flat_args, in_tree = jax.tree_flatten((args, kwargs))
        flat_f, out_tree = api.flatten_fun(f, in_tree)

        if config.omnistaging_enabled:
            # Avoid abstracting inputs by calling as a thunk
            f_thunk = lu.wrap_init(lambda: flat_f.call_wrapped(*flat_args))
            out_flat = named_call_p.bind(f_thunk, name=name)
        else:
            # Hide any args that are not a valid JaxType by partially applying
            # flat_f. The pylint pragma must sit on the line with the
            # protected access to take effect.
            dyn_argnums = [
                i for (i, x) in enumerate(flat_args)
                if jax.api._valid_jaxtype(x)  # pylint: disable=protected-access
            ]
            part_f, dyn_args = jax.argnums_partial(flat_f, dyn_argnums,
                                                   flat_args)

            # Call with a custom XLA subcomputation via named_call & unflatten result.
            out_flat = named_call_p.bind(part_f, *dyn_args, name=name)

        # out_tree is called only after bind; presumably it is populated once
        # flat_f has actually run (jax linear_util convention -- confirm).
        return jax.tree_unflatten(out_tree(), out_flat)