Esempio n. 1
0
File: xla.py Progetto: John1Tang/jax
def sharding_to_proto(sharding: SpatialSharding):
  """Converts a SpatialSharding to an OpSharding.

  See
  https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
  for details on the OpSharding proto.

  Args:
    sharding: one of
      * ``None`` -- converted to a REPLICATED sharding;
      * a tuple of ints -- converted to an OTHER sharding whose
        ``tile_assignment_dimensions`` are those ints and whose
        ``tile_assignment_devices`` are ``0 .. prod(sharding) - 1``;
      * a tuple whose first element is not an int -- treated as one
        sharding per tuple element (each ``None`` or a tuple), converted
        recursively and wrapped into a TUPLE sharding.

  Returns:
    An ``xc.OpSharding`` proto.
  """
  # NOTE(review): an empty tuple raises IndexError on ``sharding[0]``;
  # confirm whether callers can ever pass ``()``.
  if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
    # Tuple of per-element shardings rather than a spatial tile shape.
    assert all(s is None or isinstance(s, tuple) for s in sharding)
    return tuple_sharding_proto([sharding_to_proto(s) for s in sharding])  # type: ignore

  proto = xc.OpSharding()
  if sharding is None:
    proto.type = xc.OpSharding.Type.REPLICATED
  else:
    proto.type = xc.OpSharding.Type.OTHER
    proto.tile_assignment_dimensions = list(sharding)
    # np.prod replaces np.product, which was deprecated in NumPy 1.25 and
    # removed in NumPy 2.0.
    proto.tile_assignment_devices = list(range(np.prod(sharding)))  # type: ignore
  return proto
Esempio n. 2
0
def tuple_sharding_proto(elems):
  """Builds a TUPLE-typed OpSharding from per-element OpSharding protos.

  Args:
    elems: iterable of ``xc.OpSharding`` protos, one per tuple element. It is
      materialized into a list first, so one-shot iterators are accepted.

  Returns:
    An ``xc.OpSharding`` with ``type`` set to ``TUPLE`` and
    ``tuple_shardings`` set to the elements of ``elems``.
  """
  proto = xc.OpSharding()
  # Materialize first: otherwise the validation pass below would exhaust an
  # iterator and an empty sequence would be stored on the proto.
  elems = list(elems)
  assert all(isinstance(e, type(proto)) for e in elems)
  proto.type = xc.OpSharding.Type.TUPLE
  proto.tuple_shardings = elems
  return proto
Esempio n. 3
0
 def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
   """Return an OpSharding proto marked as REPLICATED.

   ``num_dimensions`` is not consulted by this implementation; the result
   is the same replicated sharding regardless of rank.
   """
   op_sharding_cls = xc.OpSharding
   replicated = op_sharding_cls()
   replicated.type = op_sharding_cls.Type.REPLICATED
   return replicated