Example #1
def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr,
                         cond_nconsts, body_nconsts):
  cond_const_tracers, body_const_tracers, init_tracers = split_list(
      tracers, [cond_nconsts, body_nconsts])
  init_avals = safe_map(lambda x: x.aval, init_tracers)
  cond_const_vals, body_const_vals, init_vals = tree_map(
      lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))

  body_fun = jaxpr_as_fun(body_jaxpr)
  cond_fun = jaxpr_as_fun(cond_jaxpr)

  def cond(*carry):
    return cond_fun(*it.chain(cond_const_vals, carry))

  def body(*carry):
    return body_fun(*it.chain(body_const_vals, carry))

  new_cond = callback_transform(cond, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  new_body = callback_transform(body, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(init_avals)

  new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(new_cond, in_tree, tuple(init_avals))
  new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(new_body, in_tree, tuple(init_avals))
  out = lcf.while_p.bind(
      *it.chain(new_cond_consts, new_body_consts, init_vals),
      cond_nconsts=len(new_cond_consts),
      body_nconsts=len(new_body_consts),
      cond_jaxpr=new_cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  return safe_map(trace.pure, out)
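
Every example on this page goes through jax.util.safe_map. Judging from the usage above, it behaves like the builtin map except that it is eager (returns a list) and asserts that all input sequences have the same length. A minimal sketch of that behavior:

def safe_map(f, *seqs):
  # Like builtin map, but eager, and every sequence must be equally long.
  seqs = [list(s) for s in seqs]
  n = len(seqs[0])
  assert all(len(s) == n for s in seqs), "safe_map: length mismatch"
  return [f(*args) for args in zip(*seqs)]

# safe_map(lambda x: x + 1, [1, 2, 3]) == [2, 3, 4]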
Example #2
    def start_tracing_body(self):
        """Called upon starting the tracing of the loop body."""
        # TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
        self.trace = self.scope.start_subtrace()
        # The entire state is carried.
        self.carried_state_names = sorted(self.scope._mutable_state.keys())
        for key in self.carried_state_names:
            init_val = self.scope._mutable_state[key]
            flat_init_vals, init_tree = tree_util.tree_flatten(init_val)
            flat_init_avals = safe_map(_BodyTracer.abstractify, flat_init_vals)
            flat_init_pvals = safe_map(pe.PartialVal.unknown, flat_init_avals)
            flat_init_vars = safe_map(self.trace.new_arg, flat_init_pvals)
            self.carried_state_vars[key] = flat_init_vars
            # Set the scope._mutable_state to new tracing variables.
            self.scope._mutable_state[key] = init_tree.unflatten(
                flat_init_vars)
            self.scope._mutable_state_aval[key] = init_tree.unflatten(
                flat_init_avals)
            # Make a copy of the initial state by unflattening the flat_init_vals
            self.carried_state_initial[key] = init_tree.unflatten(
                flat_init_vals)

        index_var_aval = _BodyTracer.abstractify(0)
        index_var_pval = pe.PartialVal.unknown(index_var_aval)
        self._index_var = self.trace.new_arg(index_var_pval)
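
The tree_util calls above are a flatten/unflatten round trip: tree_flatten turns a pytree into a flat list of leaves plus a treedef, and treedef.unflatten rebuilds the same structure around new leaves (here, fresh tracing variables). A small self-contained illustration with made-up values:

from jax import tree_util

init_val = {"b": 3.0, "w": [1.0, 2.0]}
leaves, treedef = tree_util.tree_flatten(init_val)  # leaves == [3.0, 1.0, 2.0]
rebuilt = treedef.unflatten([10 * x for x in leaves])
# rebuilt == {"b": 30.0, "w": [10.0, 20.0]}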
Example #3
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  instantiate = [instantiate] * len(out_tracers)
  out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
  out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                         instantiate, out_tracers)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  assert not env  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  return closed_jaxpr, consts
Example #4
    def __setattr__(self, key, value):
        """Update scope data to be functionalized.

    Called for *all* attribute setting.
    """
        if key in [
                "_active_ranges", "_mutable_state", "_mutable_state_aval",
                "_count_subtraces"
        ]:
            object.__setattr__(self, key, value)
        else:
            if self._active_ranges:
                if key not in self._mutable_state:
                    raise ValueError(
                        "New mutable state '{}' cannot be created inside a loop."
                        .format(key))
                assert key in self._mutable_state_aval
                old_aval = self._mutable_state_aval[key]
                flat_values, flat_tree = tree_util.tree_flatten(value)
                new_aval = flat_tree.unflatten(
                    safe_map(_BodyTracer.abstractify, flat_values))
                if old_aval != new_aval:
                    msg = (
                        f"Mutable state '{key}' is updated with new abstract value "
                        f"{new_aval}, which is different from previous one {old_aval}"
                    )
                    raise TypeError(msg)
            self._mutable_state[key] = value
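
Note the idiom: since __setattr__ intercepts every attribute write, the class's own bookkeeping fields must be written through object.__setattr__, or the interceptor would recurse into itself. A stripped-down sketch of the same pattern (class and field names below are illustrative, not from loops.py):

class Guarded:
  def __init__(self):
    object.__setattr__(self, "_state", {})
    object.__setattr__(self, "_locked", False)

  def __setattr__(self, key, value):
    if key in ("_state", "_locked"):
      object.__setattr__(self, key, value)  # bypass the interceptor
    elif self._locked and key not in self._state:
      raise ValueError(f"new state '{key}' cannot be created while locked")
    else:
      self._state[key] = value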
Example #5
def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
                                                fun_jaxpr, num_consts, **params):
  main = trace.main
  vals = [t.val for t in tracers]

  new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls)
  if primitive == cd.custom_jvp_call_jaxpr_p:
    thunk_name = 'jvp_jaxpr_thunk'
  elif primitive == cd.custom_vjp_call_jaxpr_p:
    thunk_name = 'fwd_jaxpr_thunk'
    params['bwd'] = callback_subtrace(params['bwd'], main)
  else:
    raise NotImplementedError(primitive)

  thunk = params.pop(thunk_name)
  @pe._memoize
  def new_thunk():
    thunk_jaxpr = core.ClosedJaxpr(*thunk())
    closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback, trace.strip_calls)
    return closed_jaxpr.jaxpr, closed_jaxpr.literals

  params[thunk_name] = new_thunk
  new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
  closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
  new_num_consts = len(new_consts) + num_consts
  out = primitive.bind(*it.chain(new_consts, vals), fun_jaxpr=closed_fun_jaxpr,
                       num_consts=new_num_consts, **params)
  return safe_map(trace.pure, out)
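
pe._memoize caches a zero-argument thunk so the wrapped jaxpr transformation runs at most once, however many times the primitive asks for it. A plausible stand-in for such a decorator (a sketch, not JAX's actual implementation):

import functools

def memoize_thunk(thunk):
  cell = []
  @functools.wraps(thunk)
  def memoized():
    if not cell:
      cell.append(thunk())
    return cell[0]
  return memoized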
Example #6
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry,
                        jaxpr, linear, unroll):
  const_tracers, carry_tracers, xs_tracers = split_list(tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_map(lambda x: x.aval, (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_map(lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))

  x_tracers = [t[0] for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]

  body_fun = jaxpr_as_fun(jaxpr)

  def new_body(*vals):
    out = body_fun(*vals)
    out_carry, y = split_list(out, [num_carry])
    return out_carry, y
  new_body = callback_transform(new_body, trace.callback,
                                strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(carry_avals + xs_avals)
  new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
      new_body, in_tree, tuple(carry_avals + x_avals))
  vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
  out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
                             num_consts=len(new_consts), num_carry=num_carry,
                             jaxpr=new_jaxpr, linear=linear, unroll=unroll)
  return safe_map(trace.pure, out_vals)
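
split_list, used at the top of this rule and in Example #1, cuts a flat argument list into consecutive chunks of the given sizes, plus a trailing chunk for whatever remains (semantics assumed from its usage here):

def split_list(lst, sizes):
  chunks, start = [], 0
  for size in sizes:
    chunks.append(lst[start:start + size])
    start += size
  chunks.append(lst[start:])  # the remainder
  return chunks

# split_list([1, 2, 3, 4, 5], [2, 1]) == [[1, 2], [3], [4, 5]]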
Example #7
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                            out_trees):
  vals_in = [t.val for t in tracers]
  fun = callback_subtrace(fun, self.main)
  fwd = callback_subtrace(fwd, self.main)
  bwd = callback_subtrace(bwd, self.main)
  out = primitive.bind(fun, fwd, bwd, *vals_in, out_trees=out_trees)
  return safe_map(self.pure, out)
Example #8
  def end_tracing_body(self):
    """Called when we are done tracing one iteration of the body."""
    # We will turn the body of the loop into a function that takes some values
    # for the scope state (carried_state_names) and returns the values for the
    # same state fields after one execution of the body. For some of the ranges,
    # e.g., scope.range, the function will also take the index_var as last parameter.
    in_tracers = tuple(itertools.chain(
        *[self.carried_state_vars[ms] for ms in self.carried_state_names]))
    if self.loop_builder.can_use_index_var():
      in_tracers += (self._index_var,)

    # Make the jaxpr for the body of the loop
    # TODO: See which mutable state was changed in the one iteration.
    # For now, we assume all state changes.
    body_out_tracers = []
    for key in self.carried_state_names:
      new_val = self.scope._mutable_state[key]
      flat_new_values, flat_new_tree = tree_util.tree_flatten(new_val)
      body_out_tracers.extend(flat_new_values)
      assert key in self.scope._mutable_state_aval
      old_aval = self.scope._mutable_state_aval[key]
      new_aval = flat_new_tree.unflatten(safe_map(_BodyTracer.abstractify, flat_new_values))
      if old_aval != new_aval:
        msg = (f"Mutable state '{key}' had at the end of the loop body new abstract value "
               f"{new_aval}, which is different from initial one {old_aval}")
        raise TypeError(msg)

    try:
      # If the body actually uses the index variable, and is not allowed to
      # (e.g., cond_range and while_range), then in_tracers will not contain
      # the tracer for the index_var, and trace_to_jaxpr_finalize will throw
      # an assertion error.
      body_closed_jaxpr, body_const_vals = _BodyTracer.trace_to_jaxpr_finalize(
          in_tracers=in_tracers,
          out_tracers=body_out_tracers,
          trace=self.trace)
    except UnexpectedTracerError as e:
      if "Tracer not among input tracers" in str(e):
        raise ValueError("Body of cond_range or while_range should not use the "
                         "index variable returned by iterator.") from e
      raise
    # End the subtrace for the loop body, before we trace the condition
    self.scope.end_subtrace()

    carried_init_val = tuple([self.carried_state_initial[ms]
                              for ms in self.carried_state_names])
    carried_init_vals, carried_tree = tree_util.tree_flatten(carried_init_val)
    assert len(carried_init_vals) == len(body_out_tracers)

    carried_out_vals = self.loop_builder.build_output_vals(
      self.scope, self.carried_state_names, carried_tree,
      carried_init_vals, body_closed_jaxpr, body_const_vals)
    carried_mutable_state_unflattened = tree_util.tree_unflatten(carried_tree,
                                                                 carried_out_vals)

    # Update the mutable state with the values of the changed vars, after the loop.
    for ms, mv in zip(self.carried_state_names, carried_mutable_state_unflattened):
      self.scope._mutable_state[ms] = mv
Example #9
File: loops.py  Project: wayfeng/jax
    def build_output_vals(self, scope, carried_state_names, carried_tree,
                          init_vals, body_closed_jaxpr, body_const_vals):
        # Trace the conditional function. cond_func takes 0 arguments, but
        # for lax.while we need a conditional function that takes the
        # carried_state_names. _initial_style_jaxpr will start its own trace and
        # will create tracers for all the carried state. We must put these values
        # in the scope._mutable_state before we trace the conditional
        # function.
        def cond_func_wrapped(*args):
            assert len(args) == len(carried_state_names)
            for ms, init_ms in zip(carried_state_names, args):
                scope._mutable_state[ms] = init_ms
            res = self.cond_func()
            # Conditional function is not allowed to modify the scope state
            for ms, init_ms in zip(carried_state_names, args):
                if not (scope._mutable_state[ms] is init_ms):
                    raise ValueError(
                        f"Conditional function modifies scope.{ms} field.")
            return res

        init_avals = safe_map(_BodyTracer.abstractify, init_vals)
        cond_jaxpr, cond_consts, cond_tree = (
            lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
                                                  carried_tree,
                                                  tuple(init_avals)))
        # TODO: share these checks with lax_control_flow.while
        if not tree_util.treedef_is_leaf(cond_tree):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got pytree {cond_tree}."
            )
        if not all(safe_map(core.typecompat, cond_jaxpr.out_avals,
                            [core.ShapedArray((), np.bool_)])):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got output type(s) "
                f"{cond_jaxpr.out_avals}.")

        return lax_control_flow.while_p.bind(*cond_consts,
                                             *body_const_vals,
                                             *init_vals,
                                             cond_nconsts=len(cond_consts),
                                             cond_jaxpr=cond_jaxpr,
                                             body_nconsts=len(body_const_vals),
                                             body_jaxpr=body_closed_jaxpr)
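
The while_p primitive bound here carries the documented semantics of lax.while_loop. Ignoring the consts plumbing, the reference behavior in plain Python is simply:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val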
Example #10
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    # We assume any extra leading in_nodes are constants and replicate them.
    num_extra_nodes = len(in_nodes) - len(in_parts)
    assert num_extra_nodes >= 0
    in_parts = (None,) * num_extra_nodes + in_parts

    args = []
    for ns, sharding in safe_zip(
            safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
        if sharding is not None:
            args.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            args.append(ns)

    sub_ctx = ctx.module_context.replace(
        name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                 core.ClosedJaxpr(call_jaxpr, ()))

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = std.CallOp(flat_output_types,
                      ir.FlatSymbolRefAttr.get(fn.name.value),
                      mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, output_types))

    out_parts = out_parts_thunk()
    outputs = []
    for ns, sharding in safe_zip(out_nodes, out_parts):
        if sharding is not None:
            outputs.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            outputs.append(ns)
    return outputs
Example #11
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  # Simulate a pass-through false branch
  in_vals, in_tree = tree_util.tree_flatten(
      (body_const_vals, tree_util.tree_unflatten(carried_tree, init_vals)))
  in_avals = safe_map(_BodyTracer.abstractify, in_vals)
  pass_through_closed_jaxpr, pass_through_const_vals, _ = (
      lax_control_flow._initial_style_jaxpr(
          lambda *args: args[1],
          in_tree,
          tuple(in_avals)))
  assert len(pass_through_const_vals) == 0
  args = [*body_const_vals, *init_vals]
  return lax_control_flow.cond_p.bind(
      self.index, *args,
      branches=(pass_through_closed_jaxpr, body_closed_jaxpr),
      linear=(False,) * len(args))
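
cond_p selects one of the branch functions by integer index, passing all branches the same operands. A plain-Python sketch of that behavior (ignoring the linear flags):

def cond(index, branches, *operands):
  return branches[index](*operands)

# With branches == (pass_through, body), index 0 returns the carried state
# unchanged, and index 1 runs one step of the loop body.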
Example #12
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[
            [], Optional[Tuple[pxla.PartitionsOrReplicated, ...]]],
        local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts, lparts)
        for arg, parts, lparts in safe_zip(abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts, lparts)
        for out, parts, lparts in safe_zip(global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Example #13
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
  vals_in = [t.val for t in tracers]
  fun = callback_subtrace(fun, self.main)
  jvp = callback_subtrace(jvp, self.main)
  out = primitive.bind(fun, jvp, *vals_in)
  return safe_map(self.pure, out)