コード例 #1
0
ファイル: xla.py プロジェクト: John1Tang/jax
def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
                  consts: Sequence[XlaOp], *args: XlaOp) -> Sequence[XlaOp]:
  """Emits the equations of `jaxpr` into `ctx.builder`, one rule call each.

  Args:
    ctx: Translation context. Its `builder`, `platform` and `name_stack`
      attributes are read; `platform` must not be None.
    jaxpr: The jaxpr whose equations are translated in order.
    consts: Ops handed to `_partitionmap` together with `jaxpr.constvars`.
    *args: Ops handed to `_partitionmap` together with `jaxpr.invars`.

  Returns:
    The result of `_flatmap` over the ops bound to `jaxpr.outvars`.

  Raises:
    NotImplementedError: If a primitive has neither a platform-specific nor a
      generic translation rule registered.
  """
  assert ctx.platform is not None

  # Maps each jaxpr variable to the ops that were bound to it.
  bindings: Dict[core.Var, Sequence[XlaOp]] = {}

  def lookup(var):
    # Literals carry their value inline, so they are turned into constants
    # instead of being looked up in `bindings`.
    if type(var) is not Literal:
      return bindings[var]
    return pyval_to_ir_constants(ctx.builder, canonicalize_dtype(var.val))

  def abstract_value(var):
    return abstractify(var.val) if type(var) is Literal else var.aval

  def bind(var, ops):
    assert ops is not None
    bindings[var] = ops

  unit_ops = pyval_to_ir_constants(ctx.builder, core.unit)
  _partitionmap(bind, [core.unitvar], unit_ops)
  _partitionmap(bind, jaxpr.constvars, consts)
  _partitionmap(bind, jaxpr.invars, args)

  for eqn in jaxpr.eqns:
    # With the experimental name stack enabled, the context's stack is
    # prepended to the equation's own stack before metadata is built.
    source_info = eqn.source_info
    if config.jax_experimental_name_stack:
      assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
      source_info = source_info.replace(
          name_stack=ctx.name_stack + eqn.source_info.name_stack)

    metadata = make_op_metadata(
        eqn.primitive, eqn.params, name_stack=ctx.name_stack,
        source_info=source_info)
    ctx.builder.set_op_metadata(metadata)
    operand_ops = _flatmap(lookup, eqn.invars)

    # A platform-specific rule takes precedence over the generic one.
    if (ctx.platform is not None and
        eqn.primitive in _backend_specific_translations[ctx.platform]):
      rule = _backend_specific_translations[ctx.platform][eqn.primitive]
    elif eqn.primitive not in _translations:
      raise NotImplementedError(
          f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
    else:
      rule = _translations[eqn.primitive]

    with source_info_util.user_context(eqn.source_info.traceback):
      if config.jax_experimental_name_stack:
        eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
      else:
        eqn_ctx = ctx
      avals_in = map(abstract_value, eqn.invars)
      avals_out = map(abstract_value, eqn.outvars)
      results = rule(eqn_ctx, avals_in, avals_out, *operand_ops, **eqn.params)

    assert isinstance(results, collections.abc.Sequence), (results, eqn)
    assert all(isinstance(op, xe.XlaOp) for op in results), (results, eqn)
    # Intended to force XLA shape error checking on every result.
    # NOTE(review): only effective if `map` is an eager variant in this module
    # (the builtin `map` is lazy) -- confirm against the file's imports.
    map(ctx.builder.get_shape, results)
    ctx.builder.clear_op_metadata()
    _partitionmap(bind, eqn.outvars, results)

  return _flatmap(lookup, jaxpr.outvars)
コード例 #2
0
ファイル: ad.py プロジェクト: yang-song/jax
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
    """Transpose `jaxpr` equation by equation, from the last one to the first.

    Args:
        jaxpr: Jaxpr whose equations are visited in reverse order.
        consts: Values paired with `jaxpr.constvars` when recording primals.
        primals_in: Values paired with `jaxpr.invars` when recording primals;
            entries for which `is_undefined_primal` is true are not recorded.
        cotangents_in: Cotangents paired with `jaxpr.outvars`.

    Returns:
        The result of mapping `read_cotangent` over `jaxpr.invars`: the
        accumulated cotangent of each input variable, or `Zero(v.aval)` where
        none was recorded.

    NOTE(review): the bare `map(...)` statements below only have an effect if
    `map` is an eager variant in this module (the builtin `map` is lazy) --
    confirm against the file's imports.
    """
    # Nothing to propagate when every incoming cotangent is a `Zero` instance.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    def write_cotangent(prim, v, ct):
        # Accumulates `ct` into the cotangent stored for `v`. `prim` is only
        # used to label assertion failures.
        # assert v not in primal_env
        # Rejects the `Zero` class object itself (as opposed to an instance).
        assert ct is not Zero, (prim, v.aval
                                )  # check for an old harmless type error
        # No cotangent to record, or `v` is a literal with no slot in `ct_env`.
        if ct is None or type(v) is Literal:
            return
        # `Zero` instances contribute nothing, so they are never stored.
        if type(ct) is Zero:
            # FIXME: This triggers a lot of failures!
            # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
            return
        # Sum with any cotangent already recorded for the same variable.
        ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
        if not core.skip_checks:
            # Check that joining the variable's aval with the accumulated
            # cotangent's aval gives back the variable's aval (weak types
            # stripped on both sides).
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
            assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval,
                                                             ct_aval)

    def read_cotangent(v):
        # Non-destructive read; a missing entry is a `Zero` of the right aval.
        return ct_env.get(v, Zero(v.aval))

    def read_primal(v):
        # Literals hold their value inline; unknown variables are reported as
        # `UndefinedPrimal`.
        if type(v) is Literal:
            return v.val
        else:
            return primal_env.get(v, UndefinedPrimal(v.aval))

    def write_primal(v, val):
        # Undefined primals are skipped so that `read_primal` later rebuilds an
        # `UndefinedPrimal` from the variable's own aval.
        if not is_undefined_primal(val):
            primal_env[v] = val

    primal_env: Dict[Any, Any] = {}
    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    #        forces primal_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    # Find the last use of each cotangent so that they can be removed
    # as soon as possible.
    # drop_cts[i] holds the outvars of equation i that are neither in
    # `jaxpr.invars` nor outvars of an earlier equation.
    drop_cts: List[Set[Any]] = []
    seen_vars: Set[Any] = set(jaxpr.invars)
    for eqn in jaxpr.eqns:
        read_set = set(eqn.outvars)  # NOTE: eqn is not transposed yet!
        drop_cts.append(read_set - seen_vars)
        seen_vars |= read_set

    ct_env: Dict[Any, Any] = {}
    # Seed the cotangent environment from the jaxpr outputs.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    # Walk equations and their drop sets together, last equation first.
    for eqn, to_drop in zip(jaxpr.eqns[::-1], drop_cts[::-1]):
        # FIXME: Some invars correspond to tangents
        invals = map(read_primal, eqn.invars)
        if eqn.primitive.multiple_results:
            cts_in = map(read_cotangent, eqn.outvars)
        else:
            # Single-result primitives receive the bare cotangent; the
            # unpacking requires exactly one outvar.
            cts_in, = map(read_cotangent, eqn.outvars)
        with source_info_util.user_context(eqn.source_info):
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                # Call/map primitives: the transpose rule receives the params
                # and call jaxpr returned by `core.extract_call_jaxpr`, plus
                # the avals of the cotangent inputs.
                cts_in_avals = [v.aval for v in eqn.outvars]
                call_jaxpr, params = core.extract_call_jaxpr(
                    eqn.primitive, eqn.params)
                cts_out = get_primitive_transpose(eqn.primitive)(params,
                                                                 call_jaxpr,
                                                                 invals,
                                                                 cts_in,
                                                                 cts_in_avals)
            else:
                cts_out = get_primitive_transpose(eqn.primitive)(cts_in,
                                                                 *invals,
                                                                 **eqn.params)
        # A rule may return the `Zero` class itself; expand it to one `Zero`
        # instance per equation input.
        cts_out = [Zero(v.aval)
                   for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
        # Remove the cotangents listed in this equation's drop set.
        for var in to_drop:
            ct_env.pop(var, None)  # NB: Constant cotangents might be missing

    cotangents_out = map(read_cotangent, jaxpr.invars)
    return cotangents_out
コード例 #3
0
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
    """Run the equations of `jaxpr` in reverse through their transpose rules.

    Args:
        jaxpr: Jaxpr whose equations are visited last-to-first.
        consts: Values paired with `jaxpr.constvars` when recording primals.
        primals_in: Values paired with `jaxpr.invars` when recording primals;
            entries for which `is_undefined_primal` is true are not recorded.
        cotangents_in: Cotangents paired with `jaxpr.outvars`.

    Returns:
        The result of mapping over `jaxpr.invars`: the accumulated cotangent
        of each input variable, or `Zero(v.aval)` where none was recorded.

    NOTE(review): the bare `map(...)` statements only have an effect if `map`
    is an eager variant in this module (the builtin `map` is lazy) -- confirm
    against the file's imports.
    """
    def zero_for(var):
        return Zero(var.aval)

    # Nothing to propagate when every incoming cotangent is a `Zero` instance.
    if not any(type(ct) is not Zero for ct in cotangents_in):
        return map(zero_for, jaxpr.invars)

    primal_env: Dict[Any, Any] = {}
    ct_env: Dict[Any, Any] = {}

    def accumulate_cotangent(prim, var, ct):
        # assert var not in primal_env
        # Rejects the `Zero` class object itself (as opposed to an instance);
        # check for an old harmless type error.
        assert ct is not Zero, (prim, var.aval)
        if ct is None or type(var) is Literal or type(ct) is Zero:
            # FIXME: This triggers a lot of failures!
            # assert var.aval == ct.aval, (prim, var.aval, ct.aval)
            return
        if var in ct_env:
            ct_env[var] = add_tangents(ct_env[var], ct)
        else:
            ct_env[var] = ct
        if not config.jax_enable_checks:
            return
        ct_aval = core.get_aval(ct_env[var])
        joined = core.lattice_join(var.aval, ct_aval).strip_weak_type()
        assert var.aval.strip_weak_type() == joined, (prim, var.aval, ct_aval)

    def pop_cotangent(var):
        # Destructive read: the entry is removed once it has been consumed.
        return ct_env.pop(var, zero_for(var))

    def primal_of(var):
        if type(var) is Literal:
            return var.val
        return primal_env.get(var, UndefinedPrimal(var.aval))

    def record_primal(var, val):
        if is_undefined_primal(val):
            return
        primal_env[var] = val

    record_primal(core.unitvar, core.unit)
    map(record_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    #        forces primal_in to contain UndefinedPrimals for tangent values!
    map(record_primal, jaxpr.invars, primals_in)

    # Seed the cotangent environment from the jaxpr outputs.
    map(partial(accumulate_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)

    for eqn in jaxpr.eqns[::-1]:
        # FIXME: Some invars correspond to tangents
        invals = map(primal_of, eqn.invars)
        if eqn.primitive.multiple_results:
            cts_in = map(pop_cotangent, eqn.outvars)
        else:
            # Single-result rules take the bare cotangent, not a sequence.
            [cts_in] = map(pop_cotangent, eqn.outvars)

        with source_info_util.user_context(eqn.source_info):
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                cts_in_avals = [out.aval for out in eqn.outvars]
                call_jaxpr, params = core.extract_call_jaxpr(
                    eqn.primitive, eqn.params)
                transpose = get_primitive_transpose(eqn.primitive)
                cts_out = transpose(params, call_jaxpr, invals, cts_in,
                                    cts_in_avals)
            else:
                transpose = get_primitive_transpose(eqn.primitive)
                cts_out = transpose(cts_in, *invals, **eqn.params)

        # A rule may return the `Zero` class itself; expand it to one `Zero`
        # instance per equation input.
        if cts_out is Zero:
            cts_out = [Zero(inp.aval) for inp in eqn.invars]
        # FIXME: Some invars correspond to primals!
        map(partial(accumulate_cotangent, eqn.primitive), eqn.invars, cts_out)

    return map(pop_cotangent, jaxpr.invars)
コード例 #4
0
ファイル: ad.py プロジェクト: jbampton/jax
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
                  consts, primals_in, cotangents_in):
  """Transposes `jaxpr` equation by equation, from the last one to the first.

  Args:
    jaxpr: Jaxpr whose equations are visited in reverse order.
    reduce_axes: Iterable of axis names. A cotangent is `psum`-reduced over
      those names that appear in its aval's `named_shape` but not in the
      `named_shape` of the variable it is written to. Also forwarded to
      call/map transpose rules and to rules in `reducing_transposes`.
    transform_stack: If truthy, the pass runs inside
      `source_info_util.transform_name_stack('transpose')`; otherwise inside a
      `contextlib.nullcontext()`.
    consts: Values paired with `jaxpr.constvars` when recording primals.
    primals_in: Values paired with `jaxpr.invars` when recording primals;
      entries for which `is_undefined_primal` is true are not recorded.
    cotangents_in: Cotangents paired with `jaxpr.outvars`.

  Returns:
    The result of mapping `read_cotangent` over `jaxpr.invars`: the accumulated
    cotangent of each input variable, or `Zero(v.aval)` where none was recorded.

  NOTE(review): the bare `map(...)` statements below only have an effect if
  `map` is an eager variant in this module (the builtin `map` is lazy) --
  confirm against the file's imports.
  """
  # Nothing to propagate when every incoming cotangent is a `Zero` instance.
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # Accumulates `ct` into the cotangent stored for `v`. `prim` is only used
    # to label assertion failures.
    # assert v not in primal_env
    # Rejects the `Zero` class object itself (as opposed to an instance).
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    # No cotangent to record, or `v` is a literal with no slot in `ct_env`.
    if ct is None or type(v) is Literal:
      return
    # `Zero` instances contribute nothing, so they are never stored.
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    # Requested axes that the cotangent carries but the target variable lacks.
    axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                           if axis_name in core.get_aval(ct).named_shape
                           and axis_name not in v.aval.named_shape)
    if axes_to_reduce:
      ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
    # Sum with any cotangent already recorded for the same variable.
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if config.jax_enable_checks:
      # Check that joining the variable's aval with the accumulated cotangent's
      # aval gives back the variable's aval (weak types and named shapes
      # stripped on both sides).
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
      assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    # Destructive read: the entry is removed from `ct_env` once consumed; a
    # missing entry is a `Zero` of the right aval.
    return ct_env.pop(v, Zero(v.aval))

  def read_primal(v):
    # Literals hold their value inline; unknown variables are reported as
    # `UndefinedPrimal`.
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # Undefined primals are skipped so that `read_primal` later rebuilds an
    # `UndefinedPrimal` from the variable's own aval.
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  ct_env: Dict[Any, Any] = {}
  ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
         else contextlib.nullcontext())
  with ctx:
    # Seed the cotangent environment from the jaxpr outputs.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    for eqn in jaxpr.eqns[::-1]:
      # FIXME: Some invars correspond to tangents
      invals = map(read_primal, eqn.invars)
      if eqn.primitive.multiple_results:
        cts_in = map(read_cotangent, eqn.outvars)
      else:
        # Single-result primitives receive the bare cotangent; the unpacking
        # requires exactly one outvar.
        cts_in, = map(read_cotangent, eqn.outvars)
      # The equation's own name stack is appended to the current one.
      name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
      with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
        if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
          # Call/map primitives: `call_jaxpr` is split out of a copy of the
          # params and passed to the transpose rule separately.
          cts_in_avals = [v.aval for v in eqn.outvars]
          params = dict(eqn.params)
          call_jaxpr = params.pop('call_jaxpr')
          cts_out = get_primitive_transpose(eqn.primitive)(
              params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
        elif eqn.primitive in reducing_transposes:
          # These rules take `reduce_axes` as their first argument.
          cts_out = reducing_transposes[eqn.primitive](
              reduce_axes, cts_in, *invals, **eqn.params)
        else:
          cts_out = get_primitive_transpose(eqn.primitive)(
              cts_in, *invals, **eqn.params)
        # A rule may return the `Zero` class itself; expand it to one `Zero`
        # instance per equation input.
        cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
コード例 #5
0
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                  consts: Sequence[Sequence[ir.Value]],
                  *args: Sequence[ir.Value]) -> Sequence[Sequence[ir.Value]]:
    """Lowers a jaxpr into mHLO, inlined into an existing function.

    Assumes that an MLIR context, location, and insertion point are set.

    Args:
        ctx: Module context; its `platform` and `name_stack` are read.
        jaxpr: The jaxpr whose equations are lowered in order.
        consts: One sequence of IR values per entry of `jaxpr.constvars`.
        *args: One sequence of IR values per entry of `jaxpr.invars`.

    Returns:
        The result of mapping over `jaxpr.outvars`: the IR values bound to
        each output variable.

    Raises:
        NotImplementedError: If no lowering or fallback translation rule is
            registered for a primitive.
        ValueError: If a rule's output is not iterable.
    """
    # Maps each jaxpr variable to the tuple of IR values bound to it.
    values: Dict[core.Var, Tuple[ir.Value, ...]] = {}

    def lookup(var: core.Var) -> Sequence[ir.Value]:
        # Literals carry their value inline, so they become constants.
        if type(var) is not core.Literal:
            return values[var]
        return ir_constants(var.val, canonicalize_types=True)

    def abstract_value(var: core.Var) -> core.AbstractValue:
        return xla.abstractify(var.val) if type(var) is core.Literal else var.aval

    def bind(var: core.Var, node: Sequence[ir.Value]):
        assert node is not None
        values[var] = tuple(node)

    assert len(args) == len(jaxpr.invars), (jaxpr, args)
    assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
    assert all(isinstance(val, ir.Value)
               for group in consts for val in group), consts
    bind(core.unitvar, ())
    # NOTE(review): these bare `map` calls only take effect if `map` is an
    # eager variant in this module -- confirm against the file's imports.
    map(bind, jaxpr.constvars, consts)
    map(bind, jaxpr.invars, args)

    for eqn in jaxpr.eqns:
        in_nodes = map(lookup, eqn.invars)
        loc = _source_info_to_location(
            eqn.primitive, eqn.params, eqn.source_info,
            name_stack=ctx.name_stack)
        with source_info_util.user_context(eqn.source_info.traceback), loc:
            # Lookup priority: platform-specific lowering, platform-specific
            # XLA translation (via fallback), generic lowering, generic XLA
            # translation (via fallback).
            platform_rules = _platform_specific_lowerings[ctx.platform]
            if eqn.primitive in platform_rules:
                rule = platform_rules[eqn.primitive]
            elif eqn.primitive in xla._backend_specific_translations[
                    ctx.platform]:
                rule = xla_fallback_lowering(eqn.primitive)
            elif eqn.primitive in _lowerings:
                rule = _lowerings[eqn.primitive]
            elif eqn.primitive in xla._translations:
                rule = xla_fallback_lowering(eqn.primitive)
            else:
                raise NotImplementedError(
                    f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
                    f"found for platform {ctx.platform}")

            avals_in = map(abstract_value, eqn.invars)
            avals_out = map(abstract_value, eqn.outvars)
            rule_ctx = LoweringRuleContext(
                module_context=ctx, primitive=eqn.primitive,
                avals_in=avals_in, avals_out=avals_out)
            operands = map(_unwrap_singleton_ir_values, in_nodes)
            ans = rule(rule_ctx, *operands, **eqn.params)

        try:
            out_nodes = tuple(map(wrap_singleton_ir_values, ans))
        except TypeError as e:
            raise ValueError("Output of translation rule must be iterable: "
                             f"{eqn}, got output {ans}") from e

        assert all(isinstance(group, tuple) for group in out_nodes), (ans, eqn)
        assert all(isinstance(val, ir.Value)
                   for group in out_nodes for val in group), (ans, eqn)
        assert len(ans) == len(eqn.outvars), (ans, eqn)
        map(bind, eqn.outvars, out_nodes)

    return map(lookup, jaxpr.outvars)
コード例 #6
0
ファイル: mlir.py プロジェクト: rsepassi/jax
def jaxpr_subcomp(ctx: LoweringContext, jaxpr: core.Jaxpr,
                  consts: Sequence[Sequence[ir.Value]],
                  *args: Sequence[ir.Value]) -> Sequence[Sequence[ir.Value]]:
    """Lowers a jaxpr into mHLO, inlined into an existing function.

    Assumes that an MLIR context, location, and insertion point are set.

    Args:
        ctx: Lowering context; its `platform` is read and the context itself
            is passed to every rule.
        jaxpr: The jaxpr whose equations are lowered in order.
        consts: One sequence of IR values per entry of `jaxpr.constvars`.
        *args: One sequence of IR values per entry of `jaxpr.invars`.

    Returns:
        The result of mapping over `jaxpr.outvars`: the IR values bound to
        each output variable.

    Raises:
        NotImplementedError: If no lowering or fallback translation rule is
            registered for a primitive.
        ValueError: If a rule's output is not iterable.
    """
    # Maps each jaxpr variable to the tuple of IR values bound to it.
    values: Dict[core.Var, Tuple[ir.Value, ...]] = {}

    def lookup(var):
        # Literals carry their value inline, so they become constants.
        if type(var) is not core.Literal:
            return values[var]
        return ir_constants(var.val, canonicalize_types=True)

    def abstract_value(var):
        return xla.abstractify(var.val) if type(var) is core.Literal else var.aval

    def bind(var, node):
        assert node is not None
        values[var] = tuple(node)

    assert len(args) == len(jaxpr.invars), (jaxpr, args)
    assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
    bind(core.unitvar, ())
    # NOTE(review): these bare `map` calls only take effect if `map` is an
    # eager variant in this module -- confirm against the file's imports.
    map(bind, jaxpr.constvars, consts)
    map(bind, jaxpr.invars, args)

    for eqn in jaxpr.eqns:
        in_nodes = map(lookup, eqn.invars)
        # TODO(phawkins): attach the primitive name, parameters, and name stack as
        # metadata.
        loc = _source_info_to_location(eqn.source_info)
        with source_info_util.user_context(eqn.source_info.traceback), loc:
            # Lookup priority: platform-specific lowering, generic lowering,
            # then the XLA translation wrapped by the fallback lowering.
            platform_rules = _platform_specific_lowerings[ctx.platform]
            if eqn.primitive in platform_rules:
                rule = platform_rules[eqn.primitive]
            elif eqn.primitive in _lowerings:
                rule = _lowerings[eqn.primitive]
            elif eqn.primitive in xla._translations:
                rule = partial(xla_fallback_lowering, eqn.primitive)
            else:
                raise NotImplementedError(
                    f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
                    "found")

            avals_in = map(abstract_value, eqn.invars)
            avals_out = map(abstract_value, eqn.outvars)
            operands = map(_unwrap_singleton_ir_values, in_nodes)
            ans = rule(ctx, avals_in, avals_out, *operands, **eqn.params)

        try:
            out_nodes = tuple(map(wrap_singleton_ir_values, ans))
        except TypeError as e:
            raise ValueError("Output of translation rule must be iterable: "
                             f"{eqn}") from e

        assert all(isinstance(group, tuple) for group in out_nodes), (ans, eqn)
        assert all(isinstance(val, ir.Value)
                   for group in out_nodes for val in group), (ans, eqn)
        assert len(ans) == len(eqn.outvars), (ans, eqn)
        map(bind, eqn.outvars, out_nodes)

    return map(lookup, jaxpr.outvars)