def _update_arrays(i, aval, xs, x):
    """Write ``x`` into position ``i`` along axis 0 of ``xs``.

    ``aval`` steers the traversal: when it is a ``core.AbstractTuple`` the
    function maps itself (with ``i`` fixed) over the corresponding elements
    of ``aval``, ``xs`` and ``x`` and re-packs the results with
    ``core.pack``; otherwise ``x`` is treated as a single array.

    Args:
        i: index along axis 0 of ``xs`` at which to write.
        aval: a ``core.AbstractValue`` (asserted) describing ``x``.
        xs: the array, or tuple of arrays, being updated.
        x: the value, or tuple of values, to write.

    Returns:
        The result of ``lax.dynamic_update_index_in_dim`` for a leaf, or
        the ``core.pack`` of the per-element results for a tuple.
    """
    assert isinstance(aval, core.AbstractValue)

    if not isinstance(aval, core.AbstractTuple):
        # Leaf: give ``x`` a leading axis of length 1 so that its rank
        # matches ``xs`` before writing it in at index ``i``.
        row_shape = (1,) + onp.shape(x)
        row = lax.reshape(x, row_shape)
        return lax.dynamic_update_index_in_dim(xs, row, i, axis=0)

    # Tuple: recurse element-wise with the same index.
    update_at_i = partial(_update_arrays, i)
    return core.pack(map(update_at_i, aval, xs, x))
def _allgather(x, dim, size, axis_name):
    """Build an array with a new axis of length ``size`` and ``psum`` it.

    A buffer shaped like ``x`` with ``size`` inserted at position ``dim``
    is filled with ``lax._const(x, 0)``; ``x`` is written into it at index
    ``axis_index(axis_name)`` along ``dim``; the buffer is then reduced
    with ``psum`` over ``axis_name``.

    Args:
        x: the local array; its ``shape`` attribute is read.
        dim: position at which the new axis is inserted and written.
        size: length of the new axis.  Presumably the size of the named
            axis -- confirm against callers.
        axis_name: name passed to ``axis_index`` and ``psum``.

    Returns:
        The ``psum`` over ``axis_name`` of the filled buffer.
    """
    gathered_shape = list(x.shape)
    gathered_shape.insert(dim, size)

    fill_value = lax._const(x, 0)
    buffer = lax.full(gathered_shape, fill_value)

    # Every other index along ``dim`` keeps the fill value, so the sum
    # only picks up what each participant wrote at its own index.
    slot = axis_index(axis_name)
    buffer = lax.dynamic_update_index_in_dim(buffer, x, slot, dim)
    return psum(buffer, axis_name)
def _expand(dim, size, axis_name, x):
    """Embed ``x`` into a larger array with a new axis of length ``size``.

    The result has the shape of ``x`` with ``size`` inserted at position
    ``dim``.  It is filled with ``lax._const(x, 0)`` except at index
    ``axis_index(axis_name)`` along ``dim``, where ``x`` is written.

    Args:
        dim: position at which the new axis is inserted and written.
        size: length of the new axis.
        axis_name: name passed to ``axis_index`` to pick the write index.
        x: the array to embed; its ``shape`` attribute is read.

    Returns:
        The filled array produced by ``lax.dynamic_update_index_in_dim``.
    """
    expanded_shape = list(x.shape)
    expanded_shape.insert(dim, size)

    background = lax.full(expanded_shape, lax._const(x, 0))
    position = axis_index(axis_name)
    return lax.dynamic_update_index_in_dim(background, x, position, dim)
def body(i, dst):
    """Copy entry ``i`` of ``src`` into ``dst`` at index ``i + offset``.

    ``src``, ``offset`` and ``axis`` are free variables taken from the
    enclosing scope (not visible in this excerpt).  The (i, carry) ->
    carry signature suggests a loop body -- confirm against the caller.

    Args:
        i: index of the entry to read from ``src`` along ``axis``.
        dst: the array being written into.

    Returns:
        ``dst`` with the entry at ``i + offset`` along ``axis`` replaced.
    """
    piece = lax.dynamic_index_in_dim(src, i, axis)
    target_index = i + offset
    return lax.dynamic_update_index_in_dim(dst, piece, target_index, axis)
def _update_array(i, aval, xs, x):
    """Write ``x`` into ``xs`` at index ``i`` along axis 0, skipping units.

    Args:
        i: index along axis 0 of ``xs`` at which to write.
        aval: abstract value for ``x``; compared by identity against
            ``core.abstract_unit``.
        xs: the array being updated.
        x: the value to write.

    Returns:
        ``core.unit`` when ``aval`` is ``core.abstract_unit``; otherwise
        the result of ``lax.dynamic_update_index_in_dim(xs, x, i, 0)``.
    """
    if aval is not core.abstract_unit:
        return lax.dynamic_update_index_in_dim(xs, x, i, 0)
    # Unit values carry no data, so there is nothing to write.
    return core.unit
def _update_arrays(i, aval, xs, x):
    """Write ``x`` into position ``i`` along axis 0 of ``xs``.

    When ``aval`` is a ``core.AbstractTuple`` the function maps itself
    (with ``i`` fixed) over the corresponding elements of ``aval``, ``xs``
    and ``x`` and re-packs the results with ``core.pack``.  Otherwise
    ``x`` is indexed with ``[None, ...]`` to prepend a unit axis and is
    written into ``xs`` at index ``i``.

    Args:
        i: index along axis 0 of ``xs`` at which to write.
        aval: a ``core.AbstractValue`` (asserted) describing ``x``.
        xs: the array, or tuple of arrays, being updated.
        x: the value, or tuple of values, to write.

    Returns:
        The result of ``lax.dynamic_update_index_in_dim`` for a leaf, or
        the ``core.pack`` of the per-element results for a tuple.
    """
    assert isinstance(aval, core.AbstractValue)

    if not isinstance(aval, core.AbstractTuple):
        # Leaf: prepend a unit axis so the update lines up with ``xs``.
        row = x[None, ...]
        return lax.dynamic_update_index_in_dim(xs, row, i, axis=0)

    # Tuple: recurse element-wise with the same index.
    recurse_at_i = partial(_update_arrays, i)
    return core.pack(map(recurse_at_i, aval, xs, x))