예제 #1
0
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
  """Emit XLA ops that derive an int32 index for `axis_name` from the replica id.

  The emitted computation is `(ReplicaId // stride) % extent`, evaluated on
  uint32 values and then converted to int32, where:

    * `extent` is the entry of `axis_env.sizes` at the position of `axis_name`
      in `axis_env.names`;
    * `stride` is the product of the sizes that follow that position,
      multiplied by `axis_env.nreps // prod(axis_env.sizes)`.

  Args:
    c: builder handed to `xb.constant` and `xops.ReplicaId`.
    axis_name: name to look up in `axis_env.names`; `list.index` raises
      `ValueError` when it is not present.
    axis_env: object exposing `names`, `sizes` and `nreps`.
    platform: accepted but not used by this rule.

  Returns:
    The op produced by `xops.ConvertElementType` with the int32 element type.
  """
  del platform  # Not needed here; kept only for the rule's calling convention.

  position = list(axis_env.names).index(axis_name)
  all_sizes = axis_env.sizes

  # Replicas not accounted for by any entry of `sizes`.
  # NOTE(review): presumably these vary fastest in the replica id -- confirm.
  leftover_replicas = axis_env.nreps // prod(all_sizes)
  stride = leftover_replicas * prod(all_sizes[position + 1:])

  stride_op = xb.constant(c, np.array(stride, dtype=np.uint32))
  extent_op = xb.constant(c, np.array(all_sizes[position], dtype=np.uint32))

  replica_id = xops.ReplicaId(c)
  index_u32 = xops.Rem(xops.Div(replica_id, stride_op), extent_op)
  int32_etype = xb.dtype_to_etype(np.int32)
  return xops.ConvertElementType(index_u32, int32_etype)
예제 #2
0
 def _axis_index_translation_rule(c, nreps, sizes, axis_name):
   """Emit XLA ops computing `(ReplicaId // stride) % sizes[-1]` as int32.

   `stride` is `nreps // prod(sizes)`. The division and remainder are done on
   uint32 values and the result is converted to int32.

   Args:
     c: builder handed to `xb.constant` and `xops.ReplicaId`.
     nreps: total replica count used to derive the stride.
     sizes: sequence of axis sizes; only its product and last entry are read.
     axis_name: accepted but not used; the last entry of `sizes` is always
       the one indexed. NOTE(review): presumably the caller guarantees that
       `axis_name` corresponds to `sizes[-1]` -- confirm.

   Returns:
     The op produced by `xops.ConvertElementType` with the int32 element type.
   """
   stride = nreps // prod(sizes)
   extent = sizes[-1]

   stride_op = xb.constant(c, np.array(stride, dtype=np.uint32))
   extent_op = xb.constant(c, np.array(extent, dtype=np.uint32))

   replica_id = xops.ReplicaId(c)
   index_u32 = xops.Rem(xops.Div(replica_id, stride_op), extent_op)
   int32_etype = xb.dtype_to_etype(np.int32)
   return xops.ConvertElementType(index_u32, int32_etype)