Ejemplo n.º 1
0
 def fun_remat(*args, **kwargs):
   """Trace ``fun`` on the call arguments and bind the trace to ``remat_p``.

   ``fun``, ``prevent_cse`` and ``policy`` are free variables supplied by the
   enclosing scope (not visible in this block).

   Returns:
     The outputs of ``remat_p.bind`` rebuilt into the output pytree
     structure reported by ``flatten_fun``.
   """
   # Positional and keyword arguments are flattened together as one pytree.
   leaves, in_tree = tree_flatten((args, kwargs))
   wrapped_fun = lu.wrap_init(fun)
   flat_fun, out_tree = flatten_fun(wrapped_fun, in_tree)

   # One abstract value per flattened input leaf.
   in_avals = []
   for leaf in leaves:
     in_avals.append(core.raise_to_shaped(core.get_aval(leaf)))

   debug = pe.debug_info(fun, in_tree, False, "checkpoint")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)

   # The constants found while tracing are passed as the leading operands,
   # ahead of the flattened arguments, alongside the converted jaxpr.
   operands = [*consts, *leaves]
   out_flat = remat_p.bind(
       *operands,
       jaxpr=pe.convert_constvars_jaxpr(jaxpr),
       prevent_cse=prevent_cse,
       differentiated=False,
       policy=policy,
   )
   return tree_unflatten(out_tree(), out_flat)
Ejemplo n.º 2
0
    def wrapped(*args, **kwargs):
        """Flatten the positional arguments and run ``fun`` via ``sharded_call``.

        ``fun``, ``static_argnums``, ``in_parts``, ``out_parts``,
        ``local_in_parts``, ``local_out_parts``, ``nparts`` and
        ``local_nparts`` are free variables supplied by the enclosing scope
        (not visible in this block).

        Args:
            *args: positional arguments for ``fun``. Positions listed in
                ``static_argnums`` are split off by ``argnums_partial``;
                the remaining ones are flattened and passed to
                ``sharded_call``.
            **kwargs: must be empty.

        Returns:
            The flat outputs of ``sharded_call`` rebuilt into the output
            pytree structure reported by ``flatten_fun``.

        Raises:
            NotImplementedError: if any keyword argument is supplied.
            ValueError: if the largest index in ``static_argnums`` is not a
                valid position in ``args``.
        """
        if kwargs:
            raise NotImplementedError(
                "sharded_jit over kwargs not yet supported")

        f = lu.wrap_init(fun)
        if static_argnums:
            if max(static_argnums) >= len(args):
                # Only a count of exactly one takes the singular noun, so
                # zero arguments must also read "arguments".
                raise ValueError(
                    f"jitted function has static_argnums={static_argnums}"
                    f" but was called with only {len(args)} positional "
                    f"argument{'s' if len(args) != 1 else ''}. "
                    "All static broadcasted arguments must be passed positionally."
                )
            dyn_argnums = [
                i for i in range(len(args)) if i not in static_argnums
            ]
            f, args = argnums_partial(f, dyn_argnums, args)

        args_flat, in_tree = tree_flatten((args, kwargs))
        # in_tree describes the (args, kwargs) pair; children()[0] is the
        # subtree for the positional arguments only.
        in_parts_flat = tuple(
            flatten_axes("sharded_jit in_parts",
                         in_tree.children()[0], in_parts))
        if local_in_parts is not None:
            local_in_parts_flat = tuple(
                flatten_axes("sharded_jit local_in_parts",
                             in_tree.children()[0], local_in_parts))
        else:
            local_in_parts_flat = None

        flat_fun, out_tree = flatten_fun(f, in_tree)
        # TODO(skye): having a function-typed param in a primitive seems dicey, is
        # there a better way?
        # The thunks defer the out_tree() call until they are invoked.
        out_parts_thunk = HashableFunction(lambda: tuple(
            flatten_axes("sharded_jit out_parts", out_tree(), out_parts)),
                                           closure=out_parts)
        # NOTE: this is a truthiness test (unlike the `is not None` test used
        # for local_in_parts), so any falsy local_out_parts takes the
        # "no local partitioning" branch.
        if local_out_parts:
            local_out_parts_thunk = HashableFunction(lambda: tuple(
                flatten_axes("sharded_jit local_out_parts", out_tree(),
                             local_out_parts)),
                                                     closure=local_out_parts)
        else:
            local_out_parts_thunk = HashableFunction(lambda: None,
                                                     closure=None)

        out = sharded_call(flat_fun,
                           *args_flat,
                           nparts=nparts,
                           in_parts=in_parts_flat,
                           out_parts_thunk=out_parts_thunk,
                           local_in_parts=local_in_parts_flat,
                           local_out_parts_thunk=local_out_parts_thunk,
                           local_nparts=local_nparts,
                           name=flat_fun.__name__)
        return tree_unflatten(out_tree(), out)
Ejemplo n.º 3
0
    def wrapper(f_orig):
        """Return a version of ``f_orig`` with its leading known arguments baked in.

        ``constant_arg_nums``, ``reorder``, ``dark`` and ``num_args_remove``
        are free variables supplied by the enclosing scope (not visible in
        this block). The first ``num_args_remove`` entries of ``dark`` are
        traced as known values and the remaining entries as unknowns.
        NOTE(review): ``dark`` is presumably the already-reordered example
        arguments; confirm against the enclosing function.

        Args:
            f_orig: the function to specialise.

        Returns:
            ``f_orig`` itself when ``constant_arg_nums`` is empty. Otherwise
            a function of the remaining arguments that evaluates the traced
            jaxpr with ``dark[:num_args_remove]`` prepended to them.
        """
        # Nothing to remove: hand the original function back untouched.
        if len(constant_arg_nums) == 0:
            return f_orig

        # Reordering args so the ones to remove are given first. This lets us
        # return a function that has completely removed those args; doing it
        # inside the traced function lets the compiler optimise the reordering.
        def f(*args):
            return f_orig(*[args[k] for k in reorder])

        def as_unknown(leaf):
            return PartialVal.unknown(get_aval(leaf).at_least_vspace())

        # Build the partial values needed by trace_to_jaxpr: a single known
        # entry for each argument being removed, and one unknown entry per
        # flattened leaf for every argument that stays.
        partial_inputs = []
        for position, value in enumerate(dark):
            if position >= num_args_remove:
                unknown_leaves = tree_flatten(
                    (tree_map(as_unknown, value), {})
                )[0]
                partial_inputs.extend(unknown_leaves)
            else:
                partial_inputs.append(PartialVal.known(value))
        partial_inputs = tuple(partial_inputs)

        # Trace the reordered function to a jaxpr.
        traced = lu.wrap_init(f)
        in_tree = tree_flatten((dark, {}))[1]
        flat_traced, out_tree = flatten_fun(traced, in_tree)
        jaxpr, _, const = trace_to_jaxpr(flat_traced, partial_inputs)

        result_tree = out_tree()
        if result_tree.num_leaves == 1 and result_tree.num_nodes == 1:
            # The output is a single bare leaf; eval_jaxpr returns a list, so
            # the value is simply its first element.
            return lambda *args: eval_jaxpr(
                jaxpr,
                const,
                *tree_flatten((*dark[:num_args_remove], *args, {}))[0],
            )[0]

        # Otherwise rebuild the structured output from the flat results.
        return lambda *args: out_tree().unflatten(
            eval_jaxpr(
                jaxpr,
                const,
                *tree_flatten((*dark[:num_args_remove], *args, {}))[0],
            )
        )
Ejemplo n.º 4
0
def make_jaxpr(fun: Callable, in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue, ...],  # with DBIdx in them
               keep_inputs: typing.Tuple[bool, ...]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  """Trace ``fun`` to a jaxpr whose constants are passed as explicit inputs.

  Args:
    fun: the callable to trace.
    in_tree: pytree structure used by ``flatten_fun`` to rebuild the
      arguments of ``fun`` from flat inputs.
    in_avals: abstract values handed to ``pe.trace_to_jaxpr_dynamic``; a
      tuple of any length.
    keep_inputs: flags forwarded as ``keep_inputs`` to
      ``pe.trace_to_jaxpr_dynamic``; a tuple of any length.

  Returns:
    A triple ``(jaxpr, consts, out_tree)``: the jaxpr after
    ``pe.convert_constvars_jaxpr``, the traced constants each passed through
    ``_canonicalize_arg``, and the output pytree structure.
  """
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, "dex_jit")
  jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug,
                                                keep_inputs=keep_inputs)
  jaxpr = pe.convert_constvars_jaxpr(jaxpr_)
  consts = [_canonicalize_arg(c) for c in consts]
  # out_tree is a thunk; it is called here, after tracing has run.
  return jaxpr, consts, out_tree()
Ejemplo n.º 5
0
def linearize(traceable, *primals, **kwargs):
  """Trace the JVP of ``traceable`` with known primals and unknown tangents.

  Args:
    traceable: the function handed to ``jvp``.
    *primals: primal inputs. Each one is used twice: as a known partial
      value, and (through its abstract value's ``at_least_vspace()``) as an
      unknown partial value.
    **kwargs: only ``has_aux`` (default ``False``) is read; any other
      keyword is ignored.

  Returns:
    ``(out_primals_consts, out_tangents_pvals, jaxpr, consts)``, with the
    result of ``aux()`` appended as a fifth element when ``has_aux`` is true.
  """
  has_aux = kwargs.pop('has_aux', False)
  if has_aux:
    jvpfun, aux = jvp(traceable, has_aux=True)
  else:
    jvpfun = jvp(traceable)

  # Known primal inputs come first, followed by the unknown inputs.
  known_pvals = tuple(pe.PartialVal.known(primal) for primal in primals)
  unknown_pvals = tuple(
      pe.PartialVal.unknown(get_aval(primal).at_least_vspace())
      for primal in primals)
  in_pvals = known_pvals + unknown_pvals

  in_tree = tree_flatten(((primals, primals), {}))[1]
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)

  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  # Every primal output is expected to be known after partial evaluation.
  assert all(pval.is_known() for pval in out_primals_pvals)
  out_primals_consts = [pval.get_known() for pval in out_primals_pvals]

  result = (out_primals_consts, out_tangents_pvals, jaxpr, consts)
  if has_aux:
    return result + (aux(),)
  return result