Example #1
File: ad.py  Project: John1Tang/jax
def _update_annotation(f: lu.WrappedFun,
                       orig_type: Optional[Tuple[Tuple[core.AbstractValue,
                                                       bool], ...]],
                       nonzeros: List[bool]) -> lu.WrappedFun:
    if orig_type is None:
        return f
    tan_types = [(aval.at_least_vspace(), keep)
                 for nz, (aval, keep) in zip(nonzeros, orig_type) if nz]
    return lu.annotate(f, (*orig_type, *tan_types))
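A minimal, runnable sketch of the filtering pattern above, with plain strings standing in for avals (hypothetical values; the `at_least_vspace()` call is omitted): tangent entries are appended only for inputs whose tangent is nonzero.

# Hypothetical stand-in values: "avals" are plain strings, flags mirror the function above.
orig_type = (("f32[3]", True), ("i32[]", True), ("f32[2,2]", True))
nonzeros = [True, False, True]

# Append a tangent entry only where the corresponding tangent is nonzero.
tan_types = [(aval, keep) for nz, (aval, keep) in zip(nonzeros, orig_type) if nz]
updated_type = (*orig_type, *tan_types)
print(updated_type)
# (('f32[3]', True), ('i32[]', True), ('f32[2,2]', True),
#  ('f32[3]', True), ('f32[2,2]', True))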
Example #2
File: batching.py  Project: xueeinstein/jax
def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
    axis_size: int, axis_name: core.AxisName, in_dims: Sequence[Optional[int]]
  ) -> lu.WrappedFun:
  if orig_type is None:
    return f
  batched_in_type = [(core.unmapped_aval(axis_size, axis_name, dim, aval), keep)
                     for dim, (aval, keep) in zip(in_dims, orig_type)]
  return lu.annotate(f, tuple(batched_in_type))
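A small sketch of the same idea using shape tuples instead of avals; `unmapped_shape` below is a hypothetical stand-in for `core.unmapped_aval`, re-inserting the batched axis of size `axis_size` at position `dim` (`None` means the input was not mapped).

def unmapped_shape(axis_size, dim, shape):
    # Hypothetical stand-in: re-insert the batched axis, or pass through if unmapped.
    if dim is None:
        return shape
    return shape[:dim] + (axis_size,) + shape[dim:]

orig_type = (((3,), True), ((4, 5), True))
in_dims = [0, None]
batched_in_type = [(unmapped_shape(8, dim, shp), keep)
                   for dim, (shp, keep) in zip(in_dims, orig_type)]
print(tuple(batched_in_type))   # (((8, 3), True), ((4, 5), True))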
Example #3
def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
    nonzeros: List[bool]
  ) -> lu.WrappedFun:
  if orig_type is None:
    return f
  # Implicit arguments never have tangents, so generate the tangent part of the
  # type annotation from explicit arguments only.
  orig_avals = [aval for aval, explicit in orig_type if explicit]
  tan_types = [(aval.at_least_vspace(), True)
               for nz, aval in zip(nonzeros, orig_avals) if nz]
  return lu.annotate(f, (*orig_type, *tan_types))
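A sketch of the explicit-only variant, again with hypothetical plain-Python values: implicit arguments (explicit flag False) are skipped when pairing against `nonzeros`, and each appended tangent entry is always marked explicit.

orig_type = (("i32[]", False),      # implicit, e.g. an inferred axis size
             ("f32[n]", True),      # explicit
             ("f32[n,n]", True))    # explicit
nonzeros = [True, False]            # one flag per *explicit* argument

orig_avals = [aval for aval, explicit in orig_type if explicit]
tan_types = [(aval, True) for nz, aval in zip(nonzeros, orig_avals) if nz]
print((*orig_type, *tan_types))
# (('i32[]', False), ('f32[n]', True), ('f32[n,n]', True), ('f32[n]', True))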
Example #4
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  if not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  if config.jax_dynamic_shapes:
    in_type = [(core.raise_to_shaped(core.get_aval(x)), True) for x in all_args]
    fun = lu.annotate(fun, tuple(in_type))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
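The flatten/unflatten round trip at the start and end of `call_transpose` can be illustrated with the public `jax.tree_util` API; the pytrees below are made up for the example.

from jax.tree_util import tree_flatten, tree_unflatten

args = (1.0, {"w": 2.0})
ct = (3.0,)
leaves, in_tree_def = tree_flatten(((), args, ct))   # empty consts, as above
print(leaves)                                        # [1.0, 2.0, 3.0]
print(tree_unflatten(in_tree_def, leaves))           # ((), (1.0, {'w': 2.0}), (3.0,))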
Example #5
File: dispatch.py  Project: romanngg/jax
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, always_lower: bool, keep_unused: bool,
                       *arg_specs):
    """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (not doing any computation
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
    if device is not None and backend is not None:
        raise ValueError("can't specify both a device and a backend for jit, "
                         "got device={} and backend={}".format(
                             device, backend))
    abstract_args, arg_devices = util.unzip2(arg_specs)
    if fun.in_type is None:
        # Add an annotation inferred from the arguments; no dynamic axes here.
        in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True)))
        fun = lu.annotate(fun, in_type)
    else:
        assert all(
            map(core.typematch, abstract_args,
                [a for a, k in fun.in_type if k]))
    with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                          "for jit in {elapsed_time} sec"):
        jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
            fun, pe.debug_info_final(fun, "jit"))
    out_avals, kept_outputs = util.unzip2(out_type)
    if any(isinstance(c, core.Tracer) for c in consts):
        raise UnexpectedTracerError("Encountered an unexpected tracer.")

    if config.jax_dynamic_shapes:
        keep_unused = True
    else:
        jaxpr = apply_outfeed_rewriter(jaxpr)

    if not keep_unused:
        jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
        consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
        abstract_args, arg_devices = util.unzip2(
            [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
        donated_invars = [
            x for i, x in enumerate(donated_invars) if i in kept_var_idx
        ]
        del kept_const_idx
    else:
        kept_var_idx = set(range(len(abstract_args)))

    nreps = jaxpr_replicas(jaxpr)
    device = _xla_callable_device(nreps, backend, device, arg_devices)
    backend = xb.get_device_backend(device) if device else xb.get_backend(
        backend)

    if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr)
            and not _backend_supports_unbounded_dynamic_shapes(backend)):
        jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

    map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))

    # Computations that only produce constants and/or only rearrange their inputs,
    # which are often produced from partial evaluation, don't need compilation,
    # and don't need to evaluate their arguments.
    if not jaxpr.eqns and not always_lower and all(kept_outputs):
        return XlaComputation(name,
                              None,
                              True,
                              None,
                              None,
                              None,
                              jaxpr=jaxpr,
                              consts=consts,
                              device=device,
                              in_avals=abstract_args,
                              out_avals=out_avals,
                              has_unordered_effects=False,
                              ordered_effects=[],
                              kept_var_idx=kept_var_idx,
                              keepalive=None)

    if not _on_exit:
        log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
        if len(abstract_args) > 10:
            msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
        else:
            msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
        logging.log(log_priority, msg)

    if nreps > 1:
        warnings.warn(
            f"The jitted function {name} includes a pmap. Using "
            "jit-of-pmap can lead to inefficient data movement, as the outer jit "
            "does not preserve sharded data representations and instead collects "
            "input and output arrays onto a single device. "
            "Consider removing the outer jit unless you know what you're doing. "
            "See https://github.com/google/jax/issues/2926.")

    if nreps > xb.device_count(backend):
        raise ValueError(
            f"compiling computation `{name}` that requires {nreps} replicas, but "
            f"only {xb.device_count(backend)} XLA devices are available.")

    if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
        raise NotImplementedError(
            "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
            "extra data movement anyway, so maybe you don't want it after all)."
        )

    # pass long arg lists as tuple for TPU
    tuple_args = len(abstract_args) > 100
    axis_env = xla.AxisEnv(nreps, (), ())
    name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
    closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
    module_name = f"jit_{fun.__name__}"
    unordered_effects = [
        eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in closed_jaxpr.effects if eff in core.ordered_effects
    ]
    module, keepalive = mlir.lower_jaxpr_to_module(
        module_name, closed_jaxpr,
        unordered_effects, ordered_effects, backend.platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
    return XlaComputation(name,
                          module,
                          False,
                          donated_invars,
                          fun.in_type,
                          out_type,
                          nreps=nreps,
                          device=device,
                          backend=backend,
                          tuple_args=tuple_args,
                          in_avals=abstract_args,
                          out_avals=out_avals,
                          has_unordered_effects=bool(unordered_effects),
                          ordered_effects=ordered_effects,
                          kept_var_idx=kept_var_idx,
                          keepalive=keepalive)
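One detail worth calling out from the pruning step above: when `keep_unused` is False, every positional list (consts, arg specs, donated_invars) is filtered with the same kept-index sets so they stay aligned with the pruned jaxpr inputs. A plain-Python sketch with hypothetical data (index sets as `_prune_unused_inputs` might return them):

consts = ["c0", "c1", "c2"]
arg_specs = [("f32[3]", None), ("f32[4]", None), ("f32[5]", None)]
donated_invars = [False, True, False]
kept_const_idx, kept_var_idx = {0, 2}, {1, 2}

consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
arg_specs = [a for i, a in enumerate(arg_specs) if i in kept_var_idx]
donated_invars = [d for i, d in enumerate(donated_invars) if i in kept_var_idx]
print(consts)          # ['c0', 'c2']
print(donated_invars)  # [True, False]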