Code Example #1
File: mlir.py Project: rsepassi/jax
def _dummy_like_aval(aval: core.AbstractValue) -> Sequence[ir.Value]:
    if isinstance(aval, core.ShapedArray):
        return [full_like_aval(0, aval)]
    elif isinstance(aval, core.AbstractToken):
        return mhlo.CreateTokenOp(aval_to_ir_type(aval)).results
    elif isinstance(aval, core.AbstractUnit):
        return ()
    else:
        raise TypeError(f"Unsupported abstract value {aval}")
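As a hedged illustration of the isinstance dispatch above, the toy helper below (describe_dummy is a made-up name, not part of JAX) reports what kind of dummy would be produced for each abstract value; unlike _dummy_like_aval it runs without an MLIR context or insertion point.

import numpy as np
from jax import core

def describe_dummy(aval):
    # Mirrors the branch structure of _dummy_like_aval, but returns a string
    # instead of building IR values.
    if isinstance(aval, core.ShapedArray):
        return f"zero-filled array, shape={aval.shape}, dtype={aval.dtype}"
    elif isinstance(aval, core.AbstractToken):
        return "a freshly created mhlo token"
    else:
        raise TypeError(f"Unsupported abstract value {aval}")

print(describe_dummy(core.ShapedArray((2, 3), np.float32)))
# zero-filled array, shape=(2, 3), dtype=float32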
Code Example #2
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
    """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    use_sharding_annotations: if True, use mhlo.sharding annotations on
      parameters and return values to express sharding. If False, use
      mhlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when mhlo.sharding annotations are
      propagated on non-entry functions during MHLO->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.
  Returns the func operation.
  """
    def aval_to_types(aval):
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    ctx.symbol_table.insert(func_op)
    ir_arg_shardings = None
    if arg_shardings is not None:
        ir_arg_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(arg_shardings, input_types)])
    ir_result_shardings = None
    if result_shardings is not None:
        ir_result_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(result_shardings, output_types)])

    if (replicated_args is not None or ir_arg_shardings is not None
            or input_output_aliases is not None):
        arg_attrs: List[Dict[str, ir.Attribute]] = [
            {} for _ in range(len(flat_input_types))
        ]

        if replicated_args is not None:
            replicated_ir_args = [
                [replicated] * len(types)
                for replicated, types in zip(replicated_args, input_types)
            ]
            for attrs, replicated in zip(arg_attrs,
                                         util.flatten(replicated_ir_args)):
                if replicated:
                    attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()

        if use_sharding_annotations and ir_arg_shardings is not None:
            for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
                if sharding is not None:
                    attrs["mhlo.sharding"] = ir.StringAttr.get(
                        sharding.SerializeToString())

        if input_output_aliases is not None:
            output_ids = util.unflatten(list(range(len(flat_output_types))),
                                        map(len, output_types))
            aliases: List[Optional[int]] = []
            for types, alias in zip(input_types, input_output_aliases):
                if alias is None:
                    aliases.extend([None] * len(types))
                else:
                    aliases.extend(output_ids[alias])

            for attrs, alias in zip(arg_attrs, aliases):
                if alias is not None:
                    attrs["tf.aliasing_output"] = i32_attr(alias)

        func_op.arg_attrs = ir.ArrayAttr.get(
            [ir.DictAttr.get(attrs) for attrs in arg_attrs])

    if use_sharding_annotations and ir_result_shardings is not None:
        func_op.result_attrs = ir.ArrayAttr.get([
            ir.DictAttr.get({} if sharding is None else {
                "mhlo.sharding":
                ir.StringAttr.get(sharding.SerializeToString())
            }) for sharding in ir_result_shardings
        ])

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        flat_args = entry_block.arguments
        if not use_sharding_annotations and ir_arg_shardings is not None:
            flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)

        unflattened_args = util.unflatten(flat_args, map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        flat_outputs = util.flatten(outs)
        if not use_sharding_annotations and ir_result_shardings is not None:
            flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                               ir_result_shardings)

        func_dialect.ReturnOp(flat_outputs)

    return func_op
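A point worth noting in the lowering above is the flatten/unflatten bookkeeping: each aval may lower to zero, one, or several IR values, so the per-aval groups are flattened to form the flat function signature and regrouped by length when block arguments and outputs are consumed. Below is a minimal, self-contained sketch of that pattern, using plain-Python stand-ins rather than the real util.flatten/util.unflatten helpers.

def flatten(groups):
    # util.flatten-style: concatenate per-aval groups into one flat list.
    return [x for group in groups for x in group]

def unflatten(flat, lengths):
    # util.unflatten-style: split a flat list back into groups of the given lengths.
    out, i = [], 0
    for n in lengths:
        out.append(flat[i:i + n])
        i += n
    return out

grouped = [["token"], [], ["a", "b"]]    # e.g. a token aval, a unit aval, a 2-value aval
flat = flatten(grouped)                  # ['token', 'a', 'b']
assert unflatten(flat, [len(g) for g in grouped]) == grouped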
Code Example #3
for ptype, dtype in dtypes.python_scalar_dtypes.items():
    register_constant_handler(ptype, partial(_python_scalar_handler, dtype))


def _device_array_constant_handler(val, canonicalize_types):
    return _ndarray_constant_handler(val.device_buffer.to_py(),
                                     canonicalize_types)


for t in device_array.device_array_types:
    register_constant_handler(t, _device_array_constant_handler)

register_constant_handler(
    core.Token,
    lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get()).result])

# Source locations


def _source_info_to_location(
        primitive: core.Primitive,
        params: Dict,
        source_info: source_info_util.SourceInfo,
        name_stack: Union[str,
                          source_info_util.NameStack] = "") -> ir.Location:
    eqn_str = str(name_stack) + core.str_eqn_compact(primitive.name, params)
    frame = source_info_util.user_frame(source_info)
    if frame is None:
        loc = ir.Location.unknown()
    else:
        loc = ir.Location.file(xla._get_canonical_source_file(frame),
                               frame.line_num, 1)
    loc = ir.Location.name(eqn_str, childLoc=loc)
    return loc
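The register_constant_handler calls at the top of this example populate a registry keyed by the Python type of the constant, and lowering then dispatches on type(val). A self-contained sketch of that pattern follows; the registry name and the placeholder tuples returned by the handlers are illustrative only, not the JAX internals or real ir.Values.

_registry = {}

def register_constant_handler(type_, handler):
    _registry[type_] = handler

def lower_constant(val, canonicalize_types=True):
    # Dispatch on the runtime type of the constant, as the real registry does.
    handler = _registry.get(type(val))
    if handler is None:
        raise TypeError(f"No constant handler registered for {type(val)}")
    return handler(val, canonicalize_types)

register_constant_handler(int, lambda v, _: [("const", "i64", v)])
register_constant_handler(float, lambda v, _: [("const", "f64", v)])
print(lower_constant(3))    # [('const', 'i64', 3)]
print(lower_constant(2.5))  # [('const', 'f64', 2.5)]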
Code Example #4
File: mlir.py Project: rsepassi/jax
def lower_jaxpr_to_fun(ctx: LoweringContext,
                       name: str,
                       jaxpr: core.ClosedJaxpr,
                       *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
    """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
  Returns the name of the function.
  """
    def aval_to_types(aval):
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        unflattened_args = util.unflatten(entry_block.arguments,
                                          map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        std.ReturnOp(util.flatten(outs))

    return symbol_name
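The docstring's note that "the name will be uniquified by the symbol table" is why this variant returns ir.StringAttr(ctx.symbol_table.insert(func_op)).value rather than the requested name. Below is a toy sketch of that uniquification behavior; the class is purely illustrative and is not the MLIR SymbolTable API.

class ToySymbolTable:
    def __init__(self):
        self._names = set()

    def insert(self, name):
        # Return a unique symbol name, suffixing a counter on collisions.
        candidate, i = name, 0
        while candidate in self._names:
            i += 1
            candidate = f"{name}_{i}"
        self._names.add(candidate)
        return candidate

table = ToySymbolTable()
print(table.insert("jit_f"))  # jit_f
print(table.insert("jit_f"))  # jit_f_1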
Code Example #5
File: mlir.py Project: GJBoth/jax
  register_constant_handler(_scalar_type, _ndarray_constant_handler)

def _python_scalar_handler(dtype, val, canonicalize_dtypes):
  return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes)

for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

def _device_array_constant_handler(val, canonicalize_types):
  return _ndarray_constant_handler(val.device_buffer.to_py(),
                                   canonicalize_types)
for t in device_array.device_array_types:
  register_constant_handler(t, _device_array_constant_handler)

register_constant_handler(
  core.Token, lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get())])

# Source locations

def _source_info_to_location(
    primitive: core.Primitive, params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "") -> ir.Location:
  eqn_str = name_stack + core.str_eqn_compact(primitive.name, params)
  frame = source_info_util.user_frame(source_info)
  if frame is None:
    loc = ir.Location.unknown()
  else:
    loc = ir.Location.file(xla._get_canonical_source_file(frame),
                           frame.line_num, 1)
  loc = ir.Location.name(eqn_str, childLoc=loc)
  return loc
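A hedged illustration, using plain strings rather than MLIR location objects, of how _source_info_to_location composes the final location: an optional file location is wrapped in a named location whose name joins the name stack with a compact equation string. The values below are made up.

name_stack = "jit(f)/"                     # example value
eqn_str = name_stack + "add"               # core.str_eqn_compact would also fold params
file_loc = "my_module.py:12:1"             # from the user frame, when one exists
named_loc = f'"{eqn_str}"({file_loc})'     # roughly how MLIR prints a name location
print(named_loc)                           # "jit(f)/add"(my_module.py:12:1)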