def wrap_with_sharding_op(x, sharding_proto: xc.OpSharding,
                          unspecified_dims: Optional[Set[int]] = None):
  """Wrap ``x`` in an MHLO ``Sharding`` custom call that carries a sharding.

  NOTE(review): this module defines ``wrap_with_sharding_op`` more than once;
  the later definition replaces this one at import time. This copy accepts the
  same ``unspecified_dims`` argument so callers behave the same either way.

  Args:
    x: value to annotate. The custom call has a single result whose type is
      ``x.type``, and ``x`` is its only operand.
    sharding_proto: sharding to attach. It is serialized with
      ``SerializeToString()`` and stored in the op's ``mhlo.sharding``
      attribute.
    unspecified_dims: optional set of dimension indices whose shardings are not
      specified, so XLA sharding propagation may change them. When non-empty,
      they are written into ``backend_config`` as ``unspecified_dims=[i,j,...]``
      in sorted order. ``None`` or an empty set produces an empty
      ``backend_config``.

  Returns:
    The ``result`` of the created ``CustomCallOp``.
  """
  # unspecified_dims indicate dimensions whose shardings are not specified and
  # XLA sharding propagation can change them.
  if unspecified_dims:
    backend_config = "unspecified_dims=[" + ",".join(
        [str(i) for i in sorted(unspecified_dims)]) + "]"
  else:
    backend_config = ""
  op = mhlo.CustomCallOp(
      [x.type], [x],
      call_target_name=ir.StringAttr.get("Sharding"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(backend_config),
      api_version=i32_attr(1),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=None,
      result_layouts=None)
  op.attributes["mhlo.sharding"] = ir.StringAttr.get(
      sharding_proto.SerializeToString())
  return op.result
def wrap_with_sharding_op(x, sharding_proto: xc.OpSharding,
                          unspecified_dims: Optional[Set[int]] = None):
  """Annotate ``x`` with a sharding by routing it through a ``Sharding`` custom call.

  Args:
    x: value to annotate. It is the sole operand of the custom call, and the
      single result has type ``x.type``.
    sharding_proto: sharding to attach. Its ``SerializeToString()`` bytes are
      stored on the op as the ``mhlo.sharding`` attribute.
    unspecified_dims: dimension indices whose shardings are left unspecified,
      so XLA sharding propagation is allowed to change them. A falsy value
      (``None`` or empty) yields an empty ``backend_config``. Otherwise the
      sorted indices are encoded as ``unspecified_dims=[i,j,...]``.

  Returns:
    The ``result`` value of the created ``CustomCallOp``.
  """
  config_text = ""
  if unspecified_dims:
    dims_csv = ",".join(map(str, sorted(unspecified_dims)))
    config_text = f"unspecified_dims=[{dims_csv}]"

  custom_call = mhlo.CustomCallOp(
      [x.type],
      [x],
      call_target_name=ir.StringAttr.get("Sharding"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(config_text),
      api_version=i32_attr(1),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=None,
      result_layouts=None,
  )
  serialized_sharding = sharding_proto.SerializeToString()
  custom_call.attributes["mhlo.sharding"] = ir.StringAttr.get(
      serialized_sharding)
  return custom_call.result