Beispiel #1
0
def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis):
    """Emulate an all-to-all collective using an all-gather plus a local pick.

    Every participant on ``axis_name`` gathers the full array. It then keeps
    only the entry whose position (along the original ``split_axis``) equals
    its own index on ``axis_name``. Finally it moves the gathered axis into
    ``concat_axis``.

    Args:
      x: the local array on this participant.
      axis_name: name of the mapped axis to communicate over.
      split_axis: axis of ``x`` that is indexed by participant index.
      concat_axis: destination position for the gathered (leading) axis.

    Returns:
      The locally selected slice, with the gathered axis moved to
      ``concat_axis``.
    """
    gathered = all_gather(x, axis_name)
    my_position = axis_index(axis_name)
    # Presumably all_gather stacks along a new leading axis. That would
    # explain why split_axis is shifted by one here and why axis 0 is the
    # one moved below. NOTE(review): confirm against all_gather's definition.
    shifted_split_axis = split_axis + 1
    picked = lax.dynamic_index_in_dim(
        gathered, my_position, axis=shifted_split_axis, keepdims=False)
    return _moveaxis(0, concat_axis, picked)
Beispiel #2
0
def _drop(x, dim, axis_name):
    """Keep only this participant's entry of ``x`` along ``dim``.

    The position selected is this participant's index on ``axis_name``. The
    indexed dimension is removed from the result (``keepdims=False``).
    """
    position = axis_index(axis_name)
    return lax.dynamic_index_in_dim(x, position, axis=dim, keepdims=False)
Beispiel #3
0
 def body(i, dst):
   """Copy entry ``i`` of ``src`` into entry ``i + offset`` of ``dst``.

   Both indices are taken along ``axis``. ``src``, ``axis`` and ``offset``
   come from the enclosing scope. The (i, carry) signature suggests use as a
   loop body (e.g. ``lax.fori_loop``). NOTE(review): confirm with the caller.
   """
   # keepdims=True keeps the indexed dimension (size 1), so the slice can be
   # written back into dst with dynamic_update_index_in_dim.
   piece = lax.dynamic_index_in_dim(src, i, axis=axis, keepdims=True)
   dest_pos = i + offset
   return lax.dynamic_update_index_in_dim(dst, piece, dest_pos, axis=axis)
Beispiel #4
0
def _index_array(i, aval, x):
  """Return element ``i`` of ``x`` along axis 0, or ``core.unit`` for units.

  If ``aval`` is the abstract unit, there is no array to index and
  ``core.unit`` is returned unchanged. Otherwise axis 0 of ``x`` is indexed
  and dropped (``keepdims=False``).
  """
  if aval is not core.abstract_unit:
    return lax.dynamic_index_in_dim(x, i, keepdims=False)
  return core.unit
Beispiel #5
0
def _index_arrays(i, aval, xs):
  """Index position ``i`` along axis 0 of every array in a (nested) value.

  Args:
    i: index to select.
    aval: abstract value describing ``xs``. It must be a
      ``core.AbstractValue``. An ``AbstractTuple`` means ``xs`` is a
      matching tuple of values, which is recursed into element-wise.
    xs: the concrete/traced value (array or tuple of values).

  Returns:
    For a tuple aval, a ``core.pack`` of the per-element results.
    Otherwise ``xs`` indexed at ``i`` along axis 0 with that axis dropped.
  """
  assert isinstance(aval, core.AbstractValue)
  if not isinstance(aval, core.AbstractTuple):
    return lax.dynamic_index_in_dim(xs, i, keepdims=False)
  # `map` is kept deliberately: the surrounding module may rebind it
  # (e.g. to a length-checking variant), so it is not swapped for zip.
  index_at_i = partial(_index_arrays, i)
  return core.pack(map(index_at_i, aval, xs))