Example #1
0
def _update_arrays(i, aval, xs, x):
  """Write `x` into slot `i` along the leading axis of `xs`.

  When `aval` is a `core.AbstractTuple`, the update is applied recursively
  to each component (walking `aval`, `xs` and `x` in lockstep) and the
  per-component results are re-packed with `core.pack`.  Otherwise `x` is
  reshaped to have an extra leading axis of length 1 and written at index
  `i` of axis 0 with `lax.dynamic_update_index_in_dim`.
  """
  assert isinstance(aval, core.AbstractValue)
  if not isinstance(aval, core.AbstractTuple):
    # Prepend a length-1 axis so `x` lines up with one slot of `xs`.
    leading_shape = (1,) + onp.shape(x)
    row = lax.reshape(x, leading_shape)
    return lax.dynamic_update_index_in_dim(xs, row, i, axis=0)
  recurse = partial(_update_arrays, i)
  return core.pack(map(recurse, aval, xs, x))
Example #2
0
def _allgather(x, dim, size, axis_name):
    """Gather `x` from every position of `axis_name` along a new axis.

    Builds a zero-filled array with a new axis of length `size` inserted at
    position `dim`. It writes `x` into the slot selected by
    `axis_index(axis_name)` along that axis, then `psum`s over `axis_name`.
    Each participant contributes only its own slot, so the sum stacks the
    per-participant values.

    Args:
      dim: position of the new axis in the output.  Negative values count
        from the end of the *output* (as with `numpy.expand_dims`), so -1
        places the new axis last.
      size: length of the new axis (presumably the size of `axis_name` —
        confirm against callers).
      axis_name: named axis to gather over.
      x: array with a `.shape` attribute.

    Returns:
      The `psum` over `axis_name` of an array of rank `len(x.shape) + 1`.
    """
    shape = list(x.shape)
    # Normalize a negative `dim` against the output rank.  Without this,
    # `list.insert(-1, ...)` would put the new axis second-to-last while
    # `dynamic_update_index_in_dim(..., -1)` targets the last output axis,
    # and the two would disagree about where `x` belongs.
    if dim < 0:
        dim += len(shape) + 1
    shape.insert(dim, size)
    # Zero background; `lax._const(x, 0)` derives the fill value from `x`.
    out = lax.full(shape, lax._const(x, 0))
    out = lax.dynamic_update_index_in_dim(out, x, axis_index(axis_name), dim)
    return psum(out, axis_name)
Example #3
0
def _expand(dim, size, axis_name, x):
    """Embed `x` into a zero array that has a new axis of length `size`.

    The new axis is inserted at position `dim` of the result.  `x` is
    written into the slot selected by `axis_index(axis_name)` along that
    axis, and all other slots keep the fill value from `lax._const(x, 0)`.

    Args:
      dim: position of the new axis in the output.  Negative values count
        from the end of the *output* (as with `numpy.expand_dims`), so -1
        places the new axis last.
      size: length of the new axis.
      axis_name: named axis whose `axis_index` picks the slot for `x`.
      x: array with a `.shape` attribute.

    Returns:
      An array of rank `len(x.shape) + 1`.
    """
    shape = list(x.shape)
    # Normalize a negative `dim` against the output rank so that
    # `list.insert` and `dynamic_update_index_in_dim` agree on the axis.
    # Otherwise insert(-1, ...) lands second-to-last, while lax's axis=-1
    # means the last axis of the output.
    if dim < 0:
        dim += len(shape) + 1
    shape.insert(dim, size)
    out = lax.full(shape, lax._const(x, 0))
    return lax.dynamic_update_index_in_dim(out, x, axis_index(axis_name), dim)
Example #4
0
 def body(i, dst):
   """Copy slice `i` of `src` into `dst` at position `i + offset`.

   Both the read and the write use the same `axis`.  `src`, `axis` and
   `offset` are free variables from the enclosing scope, which is not
   visible here.  Presumably this is a `fori_loop` body; confirm.
   """
   src_slice = lax.dynamic_index_in_dim(src, i, axis)
   target = i + offset
   return lax.dynamic_update_index_in_dim(dst, src_slice, target, axis)
Example #5
0
def _update_array(i, aval, xs, x):
  """Write `x` at index `i` of axis 0 of `xs`.

  Unit values need no storage: if `aval` is `core.abstract_unit`, the
  result is simply `core.unit` and `xs`/`x` are ignored.
  """
  if aval is not core.abstract_unit:
    return lax.dynamic_update_index_in_dim(xs, x, i, 0)
  return core.unit
Example #6
0
def _update_arrays(i, aval, xs, x):
    """Write `x` into slot `i` along the leading axis of `xs`.

    When `aval` is a `core.AbstractTuple`, the update is applied recursively
    to each component (walking `aval`, `xs` and `x` in lockstep) and the
    results are re-packed with `core.pack`.  Otherwise `x` gets a new
    leading axis of length 1 (`x[None, ...]`) and is written at index `i`
    of axis 0 with `lax.dynamic_update_index_in_dim`.
    """
    assert isinstance(aval, core.AbstractValue)
    if not isinstance(aval, core.AbstractTuple):
        expanded = x[None, ...]
        return lax.dynamic_update_index_in_dim(xs, expanded, i, axis=0)
    per_component = partial(_update_arrays, i)
    return core.pack(map(per_component, aval, xs, x))