def sharding_to_proto(sharding: SpatialSharding):
  """Converts a SpatialSharding to an OpSharding.

  See
  https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
  for details on the OpSharding proto.

  Args:
    sharding: One of:
      * ``None``, which produces a REPLICATED sharding;
      * a tuple of ints, used as ``tile_assignment_dimensions`` of an OTHER
        (tiled) sharding;
      * a tuple whose elements are each ``None`` or a tuple, which produces
        a TUPLE sharding built by converting every element recursively.

  Returns:
    An ``xc.OpSharding`` describing ``sharding``.
  """
  proto = xc.OpSharding()
  # A tuple whose first element is not an int is a tuple of per-element
  # shardings rather than a tile-count tuple.
  # NOTE(review): an empty tuple raises IndexError on ``sharding[0]`` here;
  # confirm whether callers can pass ``()``.
  if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
    assert all(s is None or isinstance(s, tuple) for s in sharding)
    return tuple_sharding_proto(list(map(sharding_to_proto, sharding)))  # type: ignore

  if sharding is None:
    proto.type = xc.OpSharding.Type.REPLICATED
  else:
    proto.type = xc.OpSharding.Type.OTHER
    proto.tile_assignment_dimensions = list(sharding)
    # One device id per tile, numbered 0..N-1 where N is the total tile count.
    # Use np.prod: np.product was deprecated in NumPy 1.25 and removed in 2.0.
    num_tiles = int(np.prod(sharding))
    proto.tile_assignment_devices = list(range(num_tiles))  # type: ignore
  return proto
def tuple_sharding_proto(elems):
  """Builds a TUPLE-type OpSharding from per-element shardings.

  Args:
    elems: Iterable of ``xc.OpSharding`` objects, one per tuple element.

  Returns:
    An ``xc.OpSharding`` of type TUPLE whose ``tuple_shardings`` are the
    given elements, in order.
  """
  # Materialize once: the validation below iterates ``elems``, which would
  # otherwise exhaust a generator/iterator before it is stored on the proto.
  elems = list(elems)
  proto = xc.OpSharding()
  assert all(isinstance(e, type(proto)) for e in elems)
  proto.type = xc.OpSharding.Type.TUPLE
  proto.tuple_shardings = elems
  return proto
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
  """Returns a fully replicated ``xc.OpSharding``.

  ``num_dimensions`` is accepted but not read; the replicated sharding built
  here does not depend on it.
  """
  replicated = xc.OpSharding()
  replicated.type = xc.OpSharding.Type.REPLICATED
  return replicated