def elementwise_tensor_product(Rs_1, Rs_2, get_l_output=o3.selection_rule):
    """
    Couple the entries of Rs_1 and Rs_2 position by position.

    Both lists are simplified, then split until they hold the same
    multiplicity at every position. Each aligned pair contributes one output
    entry per `l` yielded by `get_l_output`.

    :param Rs_1: list of (mul, l, p) triples
    :param Rs_2: list of (mul, l, p) triples, same total multiplicity as Rs_1
    :param get_l_output: callable (l_1, l_2) -> iterable of output l
    :return: Rs_out, matrix

    m_kij A_i B_j
    """
    Rs_1 = simplify(Rs_1)
    Rs_2 = simplify(Rs_2)

    assert sum(mul for mul, _, _ in Rs_1) == sum(mul for mul, _, _ in Rs_2)

    # Split the larger entry at each position so multiplicities line up.
    pos = 0
    while pos < len(Rs_1):
        mul_a, l_a, p_a = Rs_1[pos]
        mul_b, l_b, p_b = Rs_2[pos]

        if mul_a < mul_b:
            Rs_2[pos] = (mul_a, l_b, p_b)
            Rs_2.insert(pos + 1, (mul_b - mul_a, l_b, p_b))
        elif mul_b < mul_a:
            Rs_1[pos] = (mul_b, l_a, p_a)
            Rs_1.insert(pos + 1, (mul_a - mul_b, l_a, p_a))

        pos += 1

    # Collect the output entries produced by every aligned pair.
    Rs_out = []
    for (mul, l_a, p_a), (mul_other, l_b, p_b) in zip(Rs_1, Rs_2):
        assert mul == mul_other
        for l in get_l_output(l_a, l_b):
            Rs_out.append((mul, l, p_a * p_b))

    Rs_out = simplify(Rs_out)

    clebsch_gordan_tensor = torch.zeros(dim(Rs_out), dim(Rs_1), dim(Rs_2))

    # Fill the tensor block by block; the three cursors track the next free
    # offset along each axis.
    row = 0
    col_1 = 0
    col_2 = 0
    for (mul, l_a, p_a), (mul_other, l_b, p_b) in zip(Rs_1, Rs_2):
        assert mul == mul_other
        size_1 = mul * (2 * l_a + 1)
        size_2 = mul * (2 * l_b + 1)
        span_1 = slice(col_1, col_1 + size_1)
        span_2 = slice(col_2, col_2 + size_2)

        for l in get_l_output(l_a, l_b):
            size_out = mul * (2 * l + 1)
            coupling = o3.clebsch_gordan(l, l_a, l_b, cached=True) * (2 * l + 1)**0.5
            # selector[w, u, v] is non-zero only when w == u == v, i.e. copy
            # number w of the output only sees copy w of both inputs.
            selector = torch.einsum("uv,wu->wuv", torch.eye(mul), torch.eye(mul))
            block = torch.einsum("wuv,kij->wkuivj", selector, coupling)
            block = block.view(size_out, size_1, size_2)

            clebsch_gordan_tensor[row:row + size_out, span_1, span_2] = block
            row += size_out

        col_1 += size_1
        col_2 += size_2

    return Rs_out, clebsch_gordan_tensor
def tensor_product(Rs_1, Rs_2, get_l_output=o3.selection_rule, paths=False):
    """
    Compute the orthonormal change of basis Q
    from Rs_out to Rs_1 tensor product with Rs_2
    where Rs_out is a direct sum of irreducible representations

    :param Rs_1: list of (mul, l, p) triples
    :param Rs_2: list of (mul, l, p) triples
    :param get_l_output: callable (l_1, l_2) -> iterable of output l
    :param paths: when true, also return one ``[l_1, l_2, l]`` list per
        output copy (``mul_1 * mul_2`` of them for each coupled triple)
    :return: Rs_out, Q  (or Rs_out, Q, path_list when ``paths`` is true)

    example:
    _, Q = tensor_product(Rs1, Rs2)
    torch.einsum('kij,i,j->k', Q, A, B)
    """
    Rs_1 = simplify(Rs_1)
    Rs_2 = simplify(Rs_2)

    Rs_out = []
    path_list = [] if paths else None

    for mul_1, l_1, p_1 in Rs_1:
        for mul_2, l_2, p_2 in Rs_2:
            for l in get_l_output(l_1, l_2):
                if paths:
                    # Build a distinct list per copy: `[[...]] * n` would
                    # repeat references to one shared list, so mutating a
                    # single path would silently change all of them.
                    path_list.extend(
                        [l_1, l_2, l] for _ in range(mul_1 * mul_2)
                    )
                Rs_out.append((mul_1 * mul_2, l, p_1 * p_2))

    Rs_out = simplify(Rs_out)

    clebsch_gordan_tensor = torch.zeros(dim(Rs_out), dim(Rs_1), dim(Rs_2))

    # index_out walks the output axis once over all blocks; index_1 / index_2
    # walk the two input axes in the same nested order as the loops above.
    index_out = 0

    index_1 = 0
    for mul_1, l_1, _p_1 in Rs_1:
        dim_1 = mul_1 * (2 * l_1 + 1)

        index_2 = 0
        for mul_2, l_2, _p_2 in Rs_2:
            dim_2 = mul_2 * (2 * l_2 + 1)
            for l in get_l_output(l_1, l_2):
                dim_out = mul_1 * mul_2 * (2 * l + 1)
                C = o3.clebsch_gordan(l, l_1, l_2, cached=True) * (2 * l + 1)**0.5
                # I[w, u, v] maps the flat output copy index w onto the
                # pair of input copy indices (u, v).
                I = torch.eye(mul_1 * mul_2).view(mul_1 * mul_2, mul_1, mul_2)
                m = torch.einsum("wuv,kij->wkuivj", I, C).view(dim_out, dim_1, dim_2)

                clebsch_gordan_tensor[
                    index_out:index_out + dim_out,
                    index_1:index_1 + dim_1,
                    index_2:index_2 + dim_2,
                ] = m

                index_out += dim_out
            index_2 += dim_2
        index_1 += dim_1

    if paths:
        return Rs_out, clebsch_gordan_tensor, path_list
    return Rs_out, clebsch_gordan_tensor
def test_clebsch_gordan_orthogonal(self):
    """Check that (2 l_f + 1) * Q Q^T is the identity for l_in, l_out < 6."""
    with o3.torch_default_dtype(torch.float64):
        for l_out in range(6):
            for l_in in range(6):
                lowest = abs(l_out - l_in)
                highest = l_out + l_in
                for l_f in range(lowest, highest + 1):
                    rows = 2 * l_f + 1
                    # Flatten the two trailing indices into one axis.
                    flat = o3.clebsch_gordan(l_f, l_in, l_out).view(rows, -1)
                    gram = (rows * flat) @ flat.t()
                    residual = gram - torch.eye(rows)
                    rms = residual.pow(2).mean().sqrt()
                    self.assertLess(rms, 1e-10)
def test_clebsch_gordan_sh_norm(self):
    """Check that sqrt(4 pi) * Q @ Y has unit norm for l_in, l_out < 6."""
    scale = math.sqrt(4 * math.pi)
    with o3.torch_default_dtype(torch.float64):
        for l_out in range(6):
            for l_in in range(6):
                lowest = abs(l_out - l_in)
                highest = l_out + l_in
                for l_f in range(lowest, highest + 1):
                    coupling = o3.clebsch_gordan(l_out, l_in, l_f)
                    # Spherical harmonics of one random direction.
                    point = torch.randn(1, 3)
                    harmonics = o3.spherical_harmonics_xyz(l_f, point)
                    harmonics = harmonics.view(2 * l_f + 1)
                    contracted = (scale * coupling) @ harmonics
                    deviation = abs(contracted.norm() - 1)
                    self.assertLess(deviation, 1e-10)
def main():
    """Render the kernels coupling l_in to l_out as a grid of coloured spheres.

    Reads --l_in, --l_out, --n, --dpi and --sep from the command line and
    saves the figure to "kernels{l_in}to{l_out}.png".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--l_in", type=int, required=True)
    parser.add_argument("--l_out", type=int, required=True)
    parser.add_argument("--n", type=int, default=30, help="size of the SOFT grid")
    parser.add_argument("--dpi", type=float, default=100)
    parser.add_argument("--sep", type=float, default=0.5, help="space between matrices")
    args = parser.parse_args()

    torch.set_default_dtype(torch.float64)
    x, y, z, alpha, beta = spherical_surface(args.n)

    # One [dim_out, dim_in, ...] field per admissible filter order l.
    l_lowest = abs(args.l_out - args.l_in)
    l_highest = args.l_out + args.l_in
    f = torch.stack([
        torch.einsum(
            "ijk,k...->ij...",
            (o3.clebsch_gordan(args.l_out, args.l_in, l),
             o3.spherical_harmonics(l, alpha, beta)),
        )
        for l in range(l_lowest, l_highest + 1)
    ])

    nf, dim_out, dim_in, *_ = f.size()

    # Rescale to [0, 1] so the values can be fed to the colour map.
    f = 0.5 + 0.5 * f / f.abs().max()

    # Figure width in cells: nf matrices of dim_in columns plus the gaps.
    total_cols = nf * dim_in + (nf - 1) * args.sep
    fig = plt.figure(figsize=(total_cols, dim_out), dpi=args.dpi)

    width = 1 / total_cols
    height = 1 / dim_out
    cmap = plt.get_cmap("bwr")
    a = 0.6

    for index in range(nf):
        for i in range(dim_out):
            for j in range(dim_in):
                left = (index * (dim_in + args.sep) + j) * width
                bottom = (dim_out - i - 1) * height
                ax = fig.add_axes([left, bottom, width, height], projection='3d')

                fc = cmap(f[index, i, j].detach().cpu().numpy())

                ax.plot_surface(
                    x.numpy(), y.numpy(), z.numpy(),
                    rstride=1, cstride=1, facecolors=fc,
                )
                ax.set_axis_off()

                ax.set_xlim3d(-a, a)
                ax.set_ylim3d(-a, a)
                ax.set_zlim3d(-a, a)
                ax.view_init(90, 0)

    plt.savefig("kernels{}to{}.png".format(args.l_in, args.l_out), transparent=True)
def test1(self):
    """Test irr_repr and clebsch_gordan equivariance."""
    with torch_default_dtype(torch.float64):
        l_in, l_out = 3, 2

        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            points = torch.randn(100, 3)
            coupling = o3.clebsch_gordan(l_out, l_in, l_f)

            angles = torch.randn(3)
            rep_in = o3.irr_repr(l_in, *angles)
            rep_out = o3.irr_repr(l_out, *angles)

            # Kernel built from the rotated points, then acted on by D_in.
            rotated = points @ o3.rot(*angles).t()
            harmonics = o3.spherical_harmonics_xyz(l_f, rotated)
            kernel = torch.einsum("ijk,kz->zij", (coupling, harmonics))
            lhs = torch.einsum("zij,jk->zik", (kernel, rep_in))

            # D_out acting on the kernel built from the original points.
            harmonics = o3.spherical_harmonics_xyz(l_f, points)
            kernel = torch.einsum("ijk,kz->zij", (coupling, harmonics))
            rhs = torch.einsum("ij,zjk->zik", (rep_out, kernel))

            self.assertLess((lhs - rhs).norm(), 1e-5 * kernel.norm(), l_f)
def kernel_conv_fn_forward(F, Y, R, norm_coef, Rs_in, Rs_out, get_l_filters, set_of_l_filters):
    """
    Contract features, spherical harmonics, radial weights and Clebsch-Gordan
    coefficients into the output features, block by block over the pairs
    (entry of Rs_out, entry of Rs_in).

    :param F: tensor [batch, b, l_in * mul_in * m_in]
    :param Y: tensor [l_filter * m_filter, batch, a, b]
    :param R: tensor [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in, batch, a, b]
    :param Rs_in: iterable of (mul_in, l_in, p_in) triples
    :param Rs_out: iterable of (mul_out, l_out, p_out) triples
    :param get_l_filters: callable (l_in, p_in, l_out, p_out) -> sized
        collection of l_filter; an empty result skips the pair
    :param set_of_l_filters: all l whose (2 l + 1)-row blocks are stacked
        along the first axis of Y, in increasing order of l
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    batch, a, b = Y.shape[1:]
    n_out = rs.dim(Rs_out)

    kernel_conv = Y.new_zeros(batch, a, n_out)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = get_l_filters(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, :, :, begin_R:begin_R + n].contiguous().view(
                batch, a, b, mul_out, mul_in, -1
            )  # [batch, a, b, mul_out, mul_in, l_filter]
            begin_R += n

            sub_norm_coef = norm_coef[i, j]  # [batch, a, b]

            # The input block does not depend on the filter order: slice once.
            sub_F = F[..., s_in].view(batch, b, mul_in, -1)  # [batch, b, mul_in, m_in]

            # Accumulate with None as the "nothing yet" marker instead of the
            # int 0: `K is not 0` is an identity test on a literal
            # (SyntaxWarning on Python >= 3.8).
            K = None
            for k, l_filter in enumerate(l_filters):
                # Row offset of this l_filter block inside Y.
                offset = sum(2 * l + 1 for l in set_of_l_filters if l < l_filter)
                sub_Y = Y[offset:offset + 2 * l_filter + 1, ...]  # [m, batch, a, b]

                C = o3.clebsch_gordan(l_out, l_in, l_filter, cached=True, like=kernel_conv)  # [m_out, m_in, m]

                term = torch.einsum(
                    "ijk,kzab,zabuv,zab,zbvj->zaui",
                    C, sub_Y, sub_R[..., k], sub_norm_coef, sub_F,
                )  # [batch, a, mul_out, m_out]
                K = term if K is None else K + term

            if K is not None:
                kernel_conv[:, :, s_out] += K.reshape(batch, a, -1)

    return kernel_conv
def backward(ctx, grad_kernel):
    """
    Compute the gradients of the kernel convolution with respect to F, Y and R.

    Reads from ``ctx``: ``saved_tensors`` (F, Y, R, norm_coef), ``batch``,
    ``a``, ``b``, ``F_shape``, ``Y_shape``, ``R_shape``, ``Rs_in``,
    ``Rs_out``, ``get_l_filters``, ``set_of_l_filters`` and
    ``needs_input_grad``.

    :param ctx: context object carrying the attributes listed above
    :param grad_kernel: tensor [batch, a, l_out * mul_out * m_out]
    :return: (grad_F, grad_Y, grad_R, None, None, None, None, None); each of
        the first three is None when ``ctx.needs_input_grad`` does not ask
        for it. The five trailing None presumably match the remaining
        forward arguments -- confirm against the forward signature.
    """
    F, Y, R, norm_coef = ctx.saved_tensors
    batch, a, b = ctx.batch, ctx.a, ctx.b

    grad_F = grad_Y = grad_R = None

    if ctx.needs_input_grad[0]:
        grad_F = grad_kernel.new_zeros(*ctx.F_shape)  # [batch, b, l_in * mul_in * m_in]
    if ctx.needs_input_grad[1]:
        grad_Y = grad_kernel.new_zeros(*ctx.Y_shape)  # [l_filter * m_filter, batch, a, b]
    if ctx.needs_input_grad[2]:
        grad_R = grad_kernel.new_zeros(*ctx.R_shape)  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]

    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            n = mul_out * mul_in * len(l_filters)
            s_R = slice(begin_R, begin_R + n)
            begin_R += n

            if (grad_Y is not None) or (grad_F is not None):
                sub_R = R[:, :, :, s_R].contiguous().view(
                    batch, a, b, mul_out, mul_in, -1
                )  # [batch, a, b, mul_out, mul_in, l_filter]
            if grad_R is not None:
                # Fresh buffer, never a view of grad_R: `.contiguous()` hands
                # back the very same tensor when the slice is already
                # contiguous, and accumulating a view of grad_R into grad_R
                # would double the gradient.
                sub_grad_R = grad_kernel.new_zeros(
                    batch, a, b, mul_out, mul_in, len(l_filters)
                )  # [batch, a, b, mul_out, mul_in, l_filter]
            if grad_F is not None:
                # Starts from the values accumulated so far for this input
                # block; written back by plain assignment below, so it is
                # correct whether or not this aliases grad_F.
                sub_grad_F = grad_F[:, :, s_in].contiguous().view(
                    batch, b, mul_in, 2 * l_in + 1
                )  # [batch, b, mul_in, 2 * l_in + 1]
            if (grad_Y is not None) or (grad_R is not None):
                sub_F = F[..., s_in].view(batch, b, mul_in, 2 * l_in + 1)

            grad_K = grad_kernel[:, :, s_out].view(batch, a, mul_out, 2 * l_out + 1)

            sub_norm_coef = norm_coef[i, j]  # [batch, a, b]

            for k, l_filter in enumerate(l_filters):
                # Row offset of this l_filter block inside Y / grad_Y.
                offset = sum(2 * l + 1 for l in ctx.set_of_l_filters if l < l_filter)
                s_Y = slice(offset, offset + 2 * l_filter + 1)

                C = o3.clebsch_gordan(l_out, l_in, l_filter, cached=True, like=grad_kernel)  # [m_out, m_in, m]

                if (grad_F is not None) or (grad_R is not None):
                    sub_Y = Y[s_Y, ...]  # [m, batch, a, b]

                if grad_F is not None:
                    sub_grad_F += torch.einsum(
                        "zaui,ijk,kzab,zabuv,zab->zbvj",
                        grad_K, C, sub_Y, sub_R[..., k], sub_norm_coef,
                    )  # [batch, b, mul_in, 2 * l_in + 1]
                if grad_Y is not None:
                    grad_Y[s_Y, ...] += torch.einsum(
                        "zaui,ijk,zabuv,zab,zbvj->kzab",
                        grad_K, C, sub_R[..., k], sub_norm_coef, sub_F,
                    )  # [m, batch, a, b]
                if grad_R is not None:
                    sub_grad_R[..., k] = torch.einsum(
                        "zaui,ijk,kzab,zab,zbvj->zabuv",
                        grad_K, C, sub_Y, sub_norm_coef, sub_F,
                    )  # [batch, a, b, mul_out, mul_in]

            if grad_F is not None:
                grad_F[:, :, s_in] = sub_grad_F.view(batch, b, mul_in * (2 * l_in + 1))
            if grad_R is not None:
                # Each slice of R is visited exactly once, so assignment is
                # equivalent to accumulating onto the initial zeros.
                grad_R[..., s_R] = sub_grad_R.view(batch, a, b, -1)

    return grad_F, grad_Y, grad_R, None, None, None, None, None