示例#1
0
文件: mlir.py 项目: ahoenselaar/jax
def wrap_with_sharding_op(x, sharding_proto: xc.OpSharding):
    """Wrap ``x`` in a "Sharding" custom call annotated with ``sharding_proto``.

    Builds an ``mhlo.CustomCallOp`` whose single operand is ``x`` and whose
    single result has the same type as ``x``. The serialized sharding proto is
    attached to the op as its ``mhlo.sharding`` attribute.

    Args:
      x: The IR value to annotate; its ``.type`` is reused as the result type.
      sharding_proto: Sharding description, serialized via
        ``SerializeToString()`` into the ``mhlo.sharding`` string attribute.

    Returns:
      The result value of the newly created custom-call op.
    """
    result_types = [x.type]
    operands = [x]
    sharding_call = mhlo.CustomCallOp(
        result_types,
        operands,
        call_target_name=ir.StringAttr.get("Sharding"),
        has_side_effect=ir.BoolAttr.get(False),
        # No extra configuration for this variant: empty backend config.
        backend_config=ir.StringAttr.get(""),
        api_version=i32_attr(1),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=None,
        result_layouts=None,
    )
    serialized_sharding = sharding_proto.SerializeToString()
    sharding_call.attributes["mhlo.sharding"] = ir.StringAttr.get(
        serialized_sharding)
    return sharding_call.result
示例#2
0
def wrap_with_sharding_op(x,
                          sharding_proto: xc.OpSharding,
                          unspecified_dims: Optional[Set[int]] = None):
    """Wrap ``x`` in a "Sharding" custom call annotated with ``sharding_proto``.

    Builds an ``mhlo.CustomCallOp`` whose single operand is ``x`` and whose
    single result has the same type as ``x``. The serialized sharding proto is
    attached to the op as its ``mhlo.sharding`` attribute.

    Args:
      x: The IR value to annotate; its ``.type`` is reused as the result type.
      sharding_proto: Sharding description, serialized via
        ``SerializeToString()`` into the ``mhlo.sharding`` string attribute.
      unspecified_dims: Dimensions whose shardings are not specified, which
        XLA sharding propagation is allowed to change. When non-empty, they
        are encoded in ascending order into the backend config as
        ``"unspecified_dims=[d0,d1,...]"``. ``None`` or an empty set yields an
        empty backend config.

    Returns:
      The result value of the newly created custom-call op.
    """
    backend_config = ""
    if unspecified_dims:
        # Sorting makes the encoded string independent of set iteration order.
        dims_csv = ",".join(map(str, sorted(unspecified_dims)))
        backend_config = f"unspecified_dims=[{dims_csv}]"

    result_types = [x.type]
    operands = [x]
    sharding_call = mhlo.CustomCallOp(
        result_types,
        operands,
        call_target_name=ir.StringAttr.get("Sharding"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(backend_config),
        api_version=i32_attr(1),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=None,
        result_layouts=None,
    )
    serialized_sharding = sharding_proto.SerializeToString()
    sharding_call.attributes["mhlo.sharding"] = ir.StringAttr.get(
        serialized_sharding)
    return sharding_call.result