예제 #1
0
def _ppermute_batcher(vals_in, dims_in, axis_size, axis_name, perm):
  """Batching rule that realises a ppermute as a gather along the mapped axis.

  Args:
    vals_in: operands of the collective.
    dims_in: for each operand, the position of its mapped axis, or
      ``batching.not_mapped`` if the operand has no mapped axis.
    axis_size: length of the mapped axis.
    axis_name: name of the mapped axis (not used by this rule).
    perm: sequence of ``(source, destination)`` index pairs; it must contain
      exactly ``axis_size`` pairs.

  Returns:
    A pair ``(vals_out, dims_in)``. Mapped operands are permuted along their
    mapped axis; unmapped operands are passed through unchanged. The mapped
    axis positions are returned as given.
  """
  assert len(perm) == axis_size, "Permutation doesn't match the axis size!"
  # take() computes out[i] = v[indices[i]].  ppermute sends the value held at
  # `src` to position `dst`, i.e. out[dst] = v[src], so the gather table has to
  # store `src` at slot `dst` (storing dst at src would apply the inverse
  # permutation).
  gather_indices = np.full((axis_size,), -1, dtype=np.int32)
  for src, dst in perm:
    gather_indices[dst] = src
  vals_out = [
      v if d is batching.not_mapped else lax_numpy.take(v, gather_indices, d)
      for v, d in zip(vals_in, dims_in)
  ]
  return vals_out, dims_in
예제 #2
0
def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
    """Batching rule that realises a ppermute as a gather along the mapped axis.

    Args:
        frame: axis frame being batched over; its ``size`` and ``name``
            attributes are read.
        vals_in: one-element sequence holding the operand.
        dims_in: one-element sequence holding the operand's mapped axis.
        axis_name: axis name the collective was bound with; must equal
            ``frame.name``.
        perm: sequence of ``(source, destination)`` index pairs; it must
            contain exactly ``frame.size`` pairs.

    Returns:
        A pair ``(permuted_operand, mapped_axis)``.
    """
    # Local import keeps this block self-contained.
    import numpy as np

    assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
    assert axis_name == frame.name, "ppermute batcher called with wrong axis name"
    (v, ), (d, ) = vals_in, dims_in
    assert d is not batching.not_mapped
    # take() computes out[i] = v[indices[i]].  ppermute sends the value held
    # at `src` to position `dst`, i.e. out[dst] = v[src], so the gather table
    # has to store `src` at slot `dst` (storing dst at src would apply the
    # inverse permutation).
    gather_indices = [None] * frame.size
    for src, dst in perm:
        gather_indices[dst] = src
    # Hand take() an integer array rather than a Python list.  A slot still
    # holding None (a destination that no pair targets) makes this conversion
    # raise instead of producing a bogus gather.
    gather_indices = np.asarray(gather_indices, dtype=np.int32)
    return lax_numpy.take(v, gather_indices, d), d