示例#1
0
def _transpose_papply_rule(name, vals, dims, permutation):
  x, = vals
  xdim, = dims
  perm = list(permutation)
  if perm[xdim] == xdim:
    x = lax.transpose(x, perm)
    out_dim = xdim
  else:
    in_dim, = [i for i in range(len(perm)) if perm[i] == xdim]
    out_dim = perm[xdim]
    perm[in_dim] = out_dim
    perm[out_dim] = in_dim
    perm = perm[:xdim] + perm[xdim + 1:]
    perm = [i - 1 if i > xdim else i for i in perm]
    x = lax.transpose(x, perm)
    x = pswapaxes(x, name, in_dim)
  return x, xdim
示例#2
0
def _pswapaxes_serial_pmap_rule(vals, axes, axis):
  x, = vals
  axis_in, = axes
  if x.shape[axis_in] != x.shape[axis]:
    raise ValueError("pswapaxes between non-square dimensions")
  perm = list(range(x.ndim))
  perm[axis_in] = axis
  perm[axis] = axis_in
  return lax.transpose(x, perm), axis_in
示例#3
0
def _transpose_papply_rule(name, size, vals, dims, permutation):
    x, = vals
    xdim, = dims
    local_perm = [i if i < xdim else i - 1 for i in permutation if i != xdim]
    return lax.transpose(x, local_perm), permutation.index(xdim)
示例#4
0
def _moveaxis(src, dst, x):
    perm = [i for i in range(x.ndim) if i != src]
    perm.insert(dst, src)
    return lax.transpose(x, perm)