def main():
    """Plot each (basis, out, in) component of the sampled kernel as a sphere.

    Command-line arguments select the input/output representation orders
    (``--order_in`` / ``--order_out``), the SOFT grid size ``--n``, the plot
    layout (``--scale`` / ``--sep``) and a rotation given by the Euler angles
    ``--alpha`` / ``--beta`` / ``--gamma``. The figure is saved as
    ``kernels{order_in}{order_out}.png`` with a transparent background.
    """
    parser = argparse.ArgumentParser()

    # required
    parser.add_argument("--order_in", type=int, required=True)
    parser.add_argument("--order_out", type=int, required=True)

    parser.add_argument("--n", type=int, default=50, help="size of the SOFT grid")
    parser.add_argument("--scale", type=float, default=1.5, help="plot size of a sphere")
    parser.add_argument("--sep", type=float, default=1, help="plot separation size")
    parser.add_argument("--alpha", type=float, default=0)
    parser.add_argument("--beta", type=float, default=0)
    parser.add_argument("--gamma", type=float, default=0)

    args = parser.parse_args()

    f = _sample_sh_sphere(args.n, args.order_in, args.order_out, args.alpha, args.beta, args.gamma)  # f(r^-1 x)

    f = np.einsum(
        "ij,zjkba,kl->zilba",
        irr_repr(args.order_out, args.alpha, args.beta, args.gamma),
        f,
        irr_repr(args.order_in, -args.gamma, -args.beta, -args.alpha),
    )  # rho_out(r) f(r^-1 x) rho_in(r^-1)

    beta, alpha = beta_alpha(args.n)
    alpha = alpha - np.pi / (2 * args.n)

    nbase, dim_out, dim_in = f.shape[:3]

    # Normalize to [0, 1] for plotting. A constant kernel has zero range;
    # map it to 0 instead of producing NaN from 0 / 0.
    f_min, f_max = np.min(f), np.max(f)
    if f_max > f_min:
        f = (f - f_min) / (f_max - f_min)
    else:
        f = np.zeros_like(f)

    # Number of sphere columns plus the separator gaps between basis elements.
    n_cols = nbase * dim_in + (nbase - 1) * args.sep
    fig = plt.figure(figsize=(args.scale * n_cols, args.scale * dim_out))

    # Size of one axes rectangle in figure-relative coordinates (loop invariant).
    width = 1 / n_cols
    height = 1 / dim_out

    for base in range(nbase):
        for i in range(dim_out):
            for j in range(dim_in):
                left = (base * (dim_in + args.sep) + j) * width
                bottom = (dim_out - i - 1) * height  # row i = 0 is drawn at the top
                ax = fig.add_axes([left, bottom, width, height], projection='3d', aspect=1)
                ax.patch.set_visible(False)
                plot_sphere(beta, alpha, f[base, i, j])

    plt.savefig("kernels{}{}.png".format(args.order_in, args.order_out), transparent=True)
def check_basis_equivariance(basis, order_in, order_out, alpha, beta, gamma):
    """Measure, per kernel, the overlap between a kernel and its rotated version.

    Each kernel K of ``basis`` is normalized to unit norm. It is then transformed
    as rho_out(r) K(r^-1 x) rho_in(r^-1), with r given by the Euler angles
    (alpha, beta, gamma). The spatial part is resampled with
    ``scipy.ndimage.affine_transform`` about the grid center. The result is the
    inner product <K, transformed K>. That value should be close to 1 for an
    equivariant kernel, up to interpolation error.

    :param basis: tensor of shape (N, 2*order_out+1, 2*order_in+1, size, size, size);
        may live on any device and may require grad (it is detached and copied to CPU)
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param alpha, beta, gamma: Euler angles of the rotation r
    :return: CPU tensor of shape (N,) with one overlap per kernel
    """
    from se3cnn import SO3
    from scipy.ndimage import affine_transform
    import numpy as np

    n = basis.size(0)
    dim_in = 2 * order_in + 1
    dim_out = 2 * order_out + 1
    size = basis.size(-1)
    assert basis.size() == (n, dim_out, dim_in, size, size, size), basis.size()

    # scipy operates on NumPy arrays. Detach from autograd and move to the CPU
    # first; otherwise .numpy() raises for CUDA or grad-requiring tensors.
    basis = basis.detach().cpu()
    # reshape (not view) so that non-contiguous input is accepted as well
    basis = basis / basis.reshape(n, -1).norm(dim=1).view(-1, 1, 1, 1, 1, 1)

    x = basis.view(-1, size, size, size)
    y = torch.empty_like(x)

    # presumably the matrix of r^-1 (note the negated, reversed angles)
    invrot = SO3.rot(-gamma, -beta, -alpha).numpy()
    center = (np.array(x.size()[1:]) - 1) / 2
    # output[o] = input[invrot @ (o - center) + center]: rotate about the grid center
    offset = center - np.dot(invrot, center)
    for k in range(y.size(0)):
        y[k] = torch.tensor(affine_transform(x[k].numpy(), matrix=invrot, offset=offset))  # K(r^-1 x)
    y = y.view(*basis.size())

    y = torch.einsum(
        "ij,bjkxyz,kl->bilxyz",
        (
            irr_repr(order_out, alpha, beta, gamma, dtype=y.dtype),
            y,
            irr_repr(order_in, -gamma, -beta, -alpha, dtype=y.dtype),
        )
    )  # rho_out(r) K(r^-1 x) rho_in(r^-1)

    # per-kernel inner product <K, transformed K>
    return (basis * y).reshape(n, -1).sum(dim=1)
def cube_basis_kernels(size, order_in, order_out, radial_window):
    '''
    Build the basis of equivariant kernels on a cubic grid that map capsules of
    order ``order_in`` to capsules of order ``order_out``.

    :param size: number of voxels along each side of the kernel
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param radial_window: callable receiving, positionally, the
        (solutions, r_field, order_irreps) triple produced by _sample_cube; it
        windows out the radial parts and returns the basis tensor or None
    :return: None when radial_window returns None, otherwise the kernel basis of shape
        (N_basis, 2 * order_out + 1, 2 * order_in + 1, size, size, size)
    '''
    solutions, r_field, order_irreps = _sample_cube(size, order_in, order_out)
    basis = radial_window(solutions, r_field, order_irreps)
    if basis is None:
        return None

    # Sanity check with u = rotation of +pi/2 around the y axis:
    # rho_out(u) K(u^-1 x) rho_in(u^-1) must reproduce K(x).
    rotated = basis.transpose(3, 5).flip(5)  # K(u^-1 x) on the voxel grid
    rho_out = irr_repr(order_out, 0, math.pi / 2, 0, dtype=basis.dtype)
    rho_in_inv = irr_repr(order_in, 0, -math.pi / 2, 0, dtype=basis.dtype)
    rotated = torch.einsum("ij,bjkxyz,kl->bilxyz", (rho_out, rotated, rho_in_inv))
    assert torch.allclose(rotated, basis)

    return basis
def _sample_cube(size, order_in, order_out):
    '''
    Sample, for every admissible J, the equivariant kernel on a cubic grid.

    :param size: side length of the kernel
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :return: tuple (solutions, r_field, order_irreps) where
        - order_irreps lists J from |order_in - order_out| to order_in + order_out,
        - solutions[k] is the kernel for J = order_irreps[k], of shape
          (2*order_out+1, 2*order_in+1, size, size, size),
        - r_field holds each voxel's distance to the grid center, shape (size, size, size)
    '''
    dim_out = 2 * order_out + 1
    dim_in = 2 * order_in + 1

    # voxel coordinates centered on the middle of the grid, and their radii
    half = (size - 1) / 2
    coords = torch.linspace(-half, half, steps=size, dtype=torch.float64)
    squared = coords.pow(2)
    r_field = (squared[:, None, None] + squared[None, :, None] + squared[None, None, :]).sqrt()  # [x, y, z]

    order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))

    solutions = []
    for J in order_irreps:
        Y_J = _sample_sh_cube(size, J)  # [m, x, y, z]
        Q_J = _basis_transformation_Q_J(J, order_in, order_out)  # [m_out * m_in, m]
        K_J = torch.einsum('mn,nxyz->mxyz', (Q_J, Y_J))  # [m_out * m_in, x, y, z]
        K_J = K_J.view(dim_out, dim_in, size, size, size)  # [m_out, m_in, x, y, z]
        solutions.append(K_J)

        # Sanity check with u = rotation of +pi/2 around the y axis:
        # rho_out(u) K(u^-1 x) rho_in(u^-1) must reproduce K(x).
        rotated = K_J.transpose(2, 4).flip(4)  # K(u^-1 x)
        rho_out = irr_repr(order_out, 0, math.pi / 2, 0, dtype=K_J.dtype)
        rho_in_inv = irr_repr(order_in, 0, -math.pi / 2, 0, dtype=K_J.dtype)
        rotated = torch.einsum("ij,jkxyz,kl->ilxyz", (rho_out, rotated, rho_in_inv))
        assert torch.allclose(rotated, K_J)

    return solutions, r_field, order_irreps
def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    Solve for the change of basis Q_J between rho_out (x) rho_in and the order-J irrep.

    Q_J satisfies (rho_out(g) kron rho_in(g)) Q_J = Q_J rho_J(g). It is obtained
    as the unique (one-dimensional) null space of the Sylvester systems built at
    a few fixed rotations. The result is then re-checked at random rotations.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: unused
    :return: float64 tensor of shape [m_out * m_in, m], one part of the Q^-1
        matrix of the article
    """
    dim_out = 2 * order_out + 1
    dim_in = 2 * order_in + 1

    with torch_default_dtype(torch.float64):
        def _R_tensor(a, b, c):
            # rho_out(g) kron rho_in(g): [m_out * m_in, m_out * m_in]
            return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

        def _sylvester_submatrix(J, a, b, c):
            # Kronecker form of R Q = Q R_J in subspace J:
            # [(m_out * m_in) * m, (m_out * m_in) * m]
            R = _R_tensor(a, b, c)
            R_J = irr_repr(J, a, b, c)  # [m, m]
            left = kron(R, torch.eye(R_J.size(0)))
            right = kron(torch.eye(R.size(0)), R_J.t())
            return left - right

        # fixed angle triples, so the linear system is the same on every call
        fixed_angles = [
            [4.41301023, 5.56684102, 4.59384642],
            [4.93325116, 6.12697327, 4.14574096],
            [0.53878964, 4.09050444, 5.36539036],
            [2.16017393, 3.48835314, 5.55174441],
            [2.52385107, 0.2908958, 3.90040975]
        ]
        systems = [_sylvester_submatrix(J, a, b, c) for a, b, c in fixed_angles]
        null_space = get_matrices_kernel(systems)
        assert null_space.size(0) == 1, null_space.size()  # the subspace solution is unique

        Q_J = null_space[0].view(dim_out * dim_in, 2 * J + 1)  # [m_out * m_in, m]

        # re-check the intertwining relation at 4 random angle triples
        assert all(
            torch.allclose(
                _R_tensor(a.item(), b.item(), c.item()) @ Q_J,
                Q_J @ irr_repr(J, a.item(), b.item(), c.item()),
            )
            for a, b, c in torch.rand(4, 3)
        )
        assert Q_J.dtype == torch.float64
        return Q_J  # [m_out * m_in, m]
def transform(x):
    """Voxelize ``x`` and derive its "top" and "front" direction vectors.

    :param x: input accepted by ``Obj2OrientedVoxel`` (presumably a mesh object — confirm)
    :return: (voxels, (top, front)) where voxels is a float32 tensor with an extra
        leading dimension of size 1, and top / front are the images of the
        vectors [0, 0, 1] and [0, 1, 0] under the orientation matrix
    """
    voxelize = Obj2OrientedVoxel(32)
    voxels, abc = voxelize(x)
    voxels = torch.from_numpy(voxels.astype(np.float32)).unsqueeze(0)

    # order-1 irrep at the returned angles abc, composed with the xyz -> spherical
    # change of basis (meaning inferred from the helper's name — verify)
    orientation = irr_repr(1, *abc) @ xyz_vector_basis_to_spherical_basis()
    top, front = (orientation @ torch.tensor(axis) for axis in ([0, 0, 1.], [0, 1., 0]))
    return voxels, (top, front)
def rotate(x, alpha, beta, gamma):
    """Rotate a voxel tensor by the Euler angles (alpha, beta, gamma).

    Handling depends on the number of dimensions of ``x``:

    - A 4-D input is treated as a stack of fields along dim 0, and each field
      is rotated with ``rotate_scalar``.
    - When that stack has exactly 3 entries, the entries are additionally mixed
      by the order-1 irrep, which is consistent with a vector field.
    - Any other input is handed to ``rotate_scalar`` as a whole.

    :return: a new tensor with the dtype and device of ``x``; ``x`` itself is not modified
    """
    t = time_logging.start()

    # .copy(): for a CPU tensor, detach().cpu().numpy() shares memory with x, so the
    # per-channel writes below would otherwise modify the caller's tensor in place.
    y = x.detach().cpu().numpy().copy()
    R = rot(alpha, beta, gamma)

    if x.ndimension() == 4:
        for i in range(y.shape[0]):
            y[i] = rotate_scalar(y[i], R)
    else:
        y = rotate_scalar(y, R)

    x = x.new_tensor(y)  # back to x's dtype and device

    if x.ndimension() == 4 and x.size(0) == 3:
        rep = irr_repr(1, alpha, beta, gamma, dtype=x.dtype).to(x.device)
        x = torch.einsum("ij,jxyz->ixyz", (rep, x))

    time_logging.end("rotate", t)
    return x
def _sylvester_submatrix(J, a, b, c):
    '''
    Kronecker-product matrix of the Sylvester equation in subspace J at the
    rotation with Euler angles (a, b, c).

    :return: R kron I_m - I_(m_out*m_in) kron R_J^T, of shape
        [(m_out * m_in) * m, (m_out * m_in) * m]

    NOTE(review): at module level this relies on a global _R_tensor (which in turn
    reads order_in / order_out); it appears to be a module-level copy of a helper
    that also lives nested inside _basis_transformation_Q_J — confirm it is used.
    '''
    R = _R_tensor(a, b, c)  # [m_out * m_in, m_out * m_in]
    R_J = irr_repr(J, a, b, c)  # [m, m]
    left = kron(R, torch.eye(R_J.size(0)))
    right = kron(torch.eye(R.size(0)), R_J.t())
    return left - right
def _R_tensor(a, b, c):
    """Return kron(rho_out(g), rho_in(g)) for the rotation with Euler angles (a, b, c).

    NOTE(review): order_out and order_in are free names here, so at module level
    they must exist as globals; this appears to be a module-level copy of a helper
    nested inside _basis_transformation_Q_J — confirm it is actually used.
    """
    rho_out = irr_repr(order_out, a, b, c)
    rho_in = irr_repr(order_in, a, b, c)
    return kron(rho_out, rho_in)
def _R_tensor(a, b, c): return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c)) def _sylvester_submatrix(J, a, b, c):