def _transpose_papply_rule(name, vals, dims, permutation): x, = vals xdim, = dims perm = list(permutation) if perm[xdim] == xdim: x = lax.transpose(x, perm) out_dim = xdim else: in_dim, = [i for i in range(len(perm)) if perm[i] == xdim] out_dim = perm[xdim] perm[in_dim] = out_dim perm[out_dim] = in_dim perm = perm[:xdim] + perm[xdim + 1:] perm = [i - 1 if i > xdim else i for i in perm] x = lax.transpose(x, perm) x = pswapaxes(x, name, in_dim) return x, xdim
def _pswapaxes_serial_pmap_rule(vals, axes, axis): x, = vals axis_in, = axes if x.shape[axis_in] != x.shape[axis]: raise ValueError("pswapaxes between non-square dimensions") perm = list(range(x.ndim)) perm[axis_in] = axis perm[axis] = axis_in return lax.transpose(x, perm), axis_in
def _transpose_papply_rule(name, size, vals, dims, permutation): x, = vals xdim, = dims local_perm = [i if i < xdim else i - 1 for i in permutation if i != xdim] return lax.transpose(x, local_perm), permutation.index(xdim)
def _moveaxis(src, dst, x): perm = [i for i in range(x.ndim) if i != src] perm.insert(dst, src) return lax.transpose(x, perm)