def forward(self, features_1, features_2, weight=None):
    """Tensor product of two feature batches, weighted per path.

    :param features_1: tensor [..., rs.dim(self.Rs_in1)]
    :param features_2: tensor [..., rs.dim(self.Rs_in2)] with the same leading shape as ``features_1``
    :param weight: optional weights, reshaped to [-1, self.nweight]; defaults to ``self.weight``.
        A single row of weights is repeated over the whole flattened batch.
    :return: tensor [..., channel]
    """
    *size, n = features_1.size()
    features_1 = features_1.reshape(-1, n)
    assert n == rs.dim(self.Rs_in1), f"{n} is not {rs.dim(self.Rs_in1)}"
    *size2, n = features_2.size()
    features_2 = features_2.reshape(-1, n)
    assert n == rs.dim(self.Rs_in2), f"{n} is not {rs.dim(self.Rs_in2)}"
    assert size == size2

    if weight is None:
        weight = self.weight
    weight = weight.reshape(-1, self.nweight)
    if weight.shape[0] == 1:
        # share one set of weights across every sample of the flattened batch
        weight = weight.repeat(features_1.shape[0], 1)

    wigners = [getattr(self, arg) for arg in self.wigners_names]

    if features_1.shape[0] == 0:
        # empty batch: skip the kernel, but keep the input's dtype and device
        return features_1.new_zeros(*size, rs.dim(self.Rs_out))

    features = self.main(*wigners, features_1, features_2, weight)
    return features.reshape(*size, -1)
def forward(self):
    """Evaluate the kernel at the stored radii ``self.radii``.

    Radii larger than ``self.r_eps`` use the radial model combined with the stored
    spherical harmonics ``self.Y``. When ``self.linear`` is set, the remaining radii use
    its output.

    :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    # (1) Case r > 0
    # use the radial model to fix all the degrees of freedom
    # note: for the normalization we assume that the variance of R[i] is one
    R = self.R(self.radii[self.radii > self.r_eps])  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

    RY = rsh.mul_radial_angular(self.Rs_f, R, self.Y)
    if R.shape[0] == 0:
        # allocate on R's dtype/device so the masked assignment below cannot hit a dtype/device mismatch
        kernel1 = R.new_zeros(0, rs.dim(self.Rs_out), rs.dim(self.Rs_in))
    else:
        kernel1 = self.tp.right(RY)

    # (2) Case r = 0
    if self.linear is not None:
        kernel2 = self.linear()
        kernel = kernel1.new_zeros(len(self.radii), *kernel2.shape)
        kernel[self.radii > self.r_eps] = kernel1
        kernel[self.radii <= self.r_eps] = kernel2
    else:
        kernel = kernel1

    return kernel.reshape(*self.size, *kernel1.shape[1:])
def forward(self):
    """Assemble the dense equivariant linear map.

    Every (output block, input block) pair with equal ``(l, p)`` gets its own
    ``[mul_out, mul_in]`` slice of ``self.weight``, tensored with the identity on the
    ``2l+1`` components. Each output block is then divided by ``sqrt`` of the number of
    input channels that fed it.

    :return: tensor [l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    w = self.weight
    kernel = w.new_zeros(rs.dim(self.Rs_out), rs.dim(self.Rs_in))
    w_offset = 0
    row = 0
    for mul_out, l_out, p_out in self.Rs_out:
        rows = slice(row, row + mul_out * (2 * l_out + 1))
        row = rows.stop
        fan_in = 0
        col = 0
        for mul_in, l_in, p_in in self.Rs_in:
            dim_l = 2 * l_in + 1
            cols = slice(col, col + mul_in * dim_l)
            col = cols.stop
            if (l_out, p_out) != (l_in, p_in):
                continue
            n_w = mul_out * mul_in
            block_w = w[w_offset: w_offset + n_w].reshape(mul_out, mul_in)  # [mul_out, mul_in]
            w_offset += n_w
            identity = torch.eye(dim_l, dtype=w.dtype, device=w.device)
            block = torch.einsum('uv,ij->uivj', block_w, identity)
            kernel[rows, cols] = block.reshape(mul_out * (2 * l_out + 1), mul_in * dim_l)
            fan_in += mul_in
        if fan_in > 0:
            kernel[rows] /= math.sqrt(fan_in)
    return kernel
def kernel_fn_forward(Y, R, norm_coef, Rs_in, Rs_out, selection_rule, set_of_l_filters):
    """
    Build the dense kernel block by block from spherical harmonics ``Y`` and radial values ``R``.

    For each (output block i, input block j) pair whose ``selection_rule`` returns a
    non-empty list of ``l_filter``, the kernel block is
    ``sum_k norm_coef[i, j] * C(l_out, l_in, l_filter_k) Y_k R_k`` with ``C`` from
    ``o3.wigner_3j``. Pairs with no allowed ``l_filter`` stay zero.

    :param Y: tensor [batch, l_filter * m_filter]; assumed to hold one ``2l+1`` slice per
        ``l`` in ``set_of_l_filters``, in increasing ``l`` (this is how the offsets below are computed)
    :param R: tensor [batch, l_out * l_in * mul_out * mul_in * l_filter]; consumed
        sequentially, one chunk per allowed (i, j) pair
    :param norm_coef: tensor [l_out, l_in], indexed by the block indices (i, j)
    :param Rs_in: list of (mul, l, p) input representations
    :param Rs_out: list of (mul, l, p) output representations
    :param selection_rule: callable (l_in, p_in, l_out, p_out) -> list of allowed l_filter
    :param set_of_l_filters: the l values whose harmonics are stacked in ``Y``
    :return: tensor [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    batch = Y.shape[0]
    n_in = rs.dim(Rs_in)
    n_out = rs.dim(Rs_out)

    kernel = Y.new_zeros(batch, n_out, n_in)

    # note: for the normalization we assume that the variance of R[i] is one
    # begin_R walks through R; it only advances for pairs allowed by selection_rule
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = selection_rule(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, begin_R:begin_R + n].reshape(
                batch, mul_out, mul_in, len(l_filters))  # [batch, mul_out, mul_in, l_filter]
            begin_R += n

            # note: I don't know if we can vectorize this for loop because [l_filter * m_filter] cannot be put into [l_filter, m_filter]
            K = 0
            for k, l_filter in enumerate(l_filters):
                # offset of this l_filter's 2l+1 harmonics inside Y
                tmp = sum(2 * l + 1 for l in set_of_l_filters if l < l_filter)
                sub_Y = Y[:, tmp:tmp + 2 * l_filter + 1]  # [batch, m]

                C = o3.wigner_3j(l_out, l_in, l_filter, cached=True, like=kernel)  # [m_out, m_in, m]

                # note: The multiplication with `sub_R` could also be done outside of the for loop
                K += norm_coef[i, j] * torch.einsum(
                    "ijk,zk,zuv->zuivj",
                    (C, sub_Y, sub_R[..., k]))  # [batch, mul_out, m_out, mul_in, m_in]

            # defensive: K would still be the int 0 only if no l_filter contributed,
            # which the `continue` above already rules out
            if not isinstance(K, int):
                kernel[:, s_out, s_in] = K.reshape_as(kernel[:, s_out, s_in])
    return kernel
def test_tensor_product_to_dense():
    """The dense form of a TensorProduct has shape [out, in1, in2]."""
    with o3.torch_default_dtype(torch.float64):
        Rs_a = [(3, 0), (2, 1), (5, 2)]
        Rs_b = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]
        tp = rs.TensorProduct(Rs_a, Rs_b, o3.selection_rule)
        dense = tp.to_dense()
        expected = (rs.dim(tp.Rs_out), rs.dim(Rs_a), rs.dim(Rs_b))
        assert dense.shape == expected
def test_change_lmax(self):
    """Raising lmax must give a signal whose length matches the new Rs."""
    lmax, mul = 0, 1
    signal = torch.zeros(rs.dim([(mul, lmax)]))
    # NOTE(review): `mul` and `lmax` are passed positionally; change_lmax elsewhere builds
    # SphericalTensor(signal, p_val, p_arg), so these may be landing in the parity slots — confirm.
    sph = sphten.SphericalTensor(signal, mul, lmax)

    upgraded = sph.change_lmax(5)
    assert upgraded.signal.shape[0] == rs.dim(upgraded.Rs)
def change_lmax(self, lmax):
    """Return a FourierTensor truncated or zero-padded to a new ``lmax``.

    :param lmax: target maximum degree
    :return: ``self`` if ``lmax`` is unchanged. Otherwise a new FourierTensor whose signal
        (last axis) keeps the first ``rs.dim(new_Rs)`` coefficients when shrinking, or is
        padded with zeros when growing. Leading batch axes, dtype and device are preserved.
    """
    new_Rs = [(self.mul, l) for l in range(lmax + 1)]
    if self.lmax == lmax:
        return self
    elif self.lmax > lmax:
        new_signal = self.signal[..., :rs.dim(new_Rs)]
        return FourierTensor(new_signal, self.mul, lmax)
    elif self.lmax < lmax:
        # allocate like the signal: torch.zeros would use the global default dtype/device
        new_signal = self.signal.new_zeros(*self.signal.shape[:-1], rs.dim(new_Rs))
        new_signal[..., :rs.dim(self.Rs)] = self.signal
        return FourierTensor(new_signal, self.mul, lmax)
def change_lmax(self, lmax):
    """Return a SphericalTensor truncated or zero-padded to a new ``lmax``.

    :param lmax: target maximum degree
    :return: ``self`` if ``lmax`` is unchanged. Otherwise a new SphericalTensor with the same
        ``p_val``/``p_arg``, whose signal's last axis is truncated to, or zero-padded up to,
        ``rs.dim(new_Rs)``.
    """
    new_Rs = [(1, l) for l in range(lmax + 1)]
    if self.lmax == lmax:
        return self
    elif self.lmax > lmax:
        new_signal = self.signal[..., :rs.dim(new_Rs)]
        return SphericalTensor(new_signal, self.p_val, self.p_arg)
    elif self.lmax < lmax:
        # allocate like the signal: torch.zeros would use the global default dtype/device
        new_signal = self.signal.new_zeros(*self.signal.shape[:-1], rs.dim(new_Rs))
        new_signal[..., :rs.dim(self.Rs)] = self.signal
        return SphericalTensor(new_signal, self.p_val, self.p_arg)
def forward(self, features):
    """Multiply the last axis of ``features`` by ``self.kernel()``.

    :param features: tensor [..., channel]
    :return: tensor [..., channel]
    """
    lead = features.shape[:-1]
    flat = features.reshape(-1, rs.dim(self.Rs_in))
    kernel = self.kernel()
    out = torch.einsum('ij,zj->zi', kernel, flat)
    return out.reshape(*lead, rs.dim(self.Rs_out))
def test_group_kernel():
    """GroupKernel returns one [cout, cin] kernel per point and per group."""
    make_kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(5, 0, 1), (4, 1, -1)]
    Rs_out = [(3, 0, 1), (5, 1, -1)]
    n_groups = 4
    group_kernel = GroupKernel(Rs_in, Rs_out, make_kernel, n_groups)

    n_points = 7
    geometry = torch.randn(n_points, 3)
    result = group_kernel(geometry)

    assert result.dim() == 4  # [N, g, cout, cin]
    assert tuple(result.shape) == (n_points, n_groups, rs.dim(Rs_out), rs.dim(Rs_in))
def initialize_edges(x, Rs_in, pos, edge_index_dict, lmax, self_edge=1., symmetric_edges=False):
    """Initialize edge features of DataEdgeNeighbors using node features and SphericalTensor.

    Args:
        x (torch.tensor shape [N, rs.dim(Rs_in)]): Node features.
        Rs_in (rs.TY_RS_STRICT): Representation list of input.
        pos (torch.tensor shape [N, 3]): Cartesian coordinates of nodes.
        edge_index_dict (dict): Maps ``(target, source)`` node-index pairs to a sort key.
            Edges are emitted in ascending order of that value (presumably the edge index;
            confirm with the caller).
        lmax (int > 0): Maximum L to use for SphericalTensor projection of radial distance vectors
        self_edge (float, optional): L=0 feature for self edges. Defaults to 1.
        symmetric_edges (bool, optional): Constrain edge features to be symmetric in node index. Defaults to False

    Returns:
        edge_x: Edge features stacked along dim 0, one row per edge (empty if there are no edges).
        Rs_edge (rs.TY_RS_STRICT): Representation list of edge features.
    """
    from e3nn.tensor import SphericalTensor

    # change of basis from the (i, j) node-feature product to irreducible representations
    if symmetric_edges:
        Rs, Q = rs.reduce_tensor('ij=ji', i=Rs_in)
    else:
        Rs, Q = rs.reduce_tensor('ij', i=Rs_in, j=Rs_in)
    Q = Q.reshape(-1, rs.dim(Rs_in), rs.dim(Rs_in))

    Rs_sph = [(1, l, (-1)**l) for l in range(lmax + 1)]
    tp_kernel = rs.TensorProduct(Rs, Rs_sph, o3.selection_rule)

    sorted_edges = sorted(edge_index_dict.items(), key=lambda item: item[1])
    if not sorted_edges:
        return x.new_zeros(0, rs.dim(tp_kernel.Rs_out)), tp_kernel.Rs_out

    edge_x = []
    for (target, source), _ in sorted_edges:
        Ia = x[target]
        Ib = x[source]
        vector = (pos[source] - pos[target]).reshape(-1, 3)
        if torch.allclose(vector, torch.zeros_like(vector)):
            # coincident nodes: only the L=0 component carries the self-edge value
            signal = vector.new_zeros(rs.dim(Rs_sph))
            signal[0] = self_edge
        else:
            signal = SphericalTensor.from_geometry(vector, lmax=lmax).signal
            if symmetric_edges:
                # average the projections of +vector and -vector
                signal += SphericalTensor.from_geometry(-vector, lmax=lmax).signal
                signal *= 0.5
        output = torch.einsum('kij,i,j->k', Q, Ia, Ib)
        output = tp_kernel(output, signal)
        edge_x.append(output)
    edge_x = torch.stack(edge_x, dim=0)
    return edge_x, tp_kernel.Rs_out
def forward(self, x):
    """Alternate the linear layers with the split activation, then apply ``self.tail``.

    After each layer the channel axis is split into ``self.mul`` copies. In each copy, the
    first ``rs.dim(self.act1.Rs_in)`` components go through ``act1`` and the next
    ``rs.dim(self.act2.Rs_in)`` go through ``act2``.
    """
    for linear in self.layers:
        h = linear(x)
        h = h.reshape(*h.shape[:-1], self.mul, -1)  # put multiplicity into batch
        n1 = rs.dim(self.act1.Rs_in)
        n2 = rs.dim(self.act2.Rs_in)
        h = torch.cat([
            self.act1(h.narrow(-1, 0, n1)),
            self.act2(h.narrow(-1, n1, n2)),
        ], dim=-1)
        x = h.reshape(*h.shape[:-2], -1)  # put back into representation
    return self.tail(x)
def forward(self, features_1, features_2):
    """Tensor product of two feature batches through the sparse buffer ``T``.

    :param features_1: tensor [..., rs.dim(self.Rs_in1)]
    :param features_2: tensor [..., rs.dim(self.Rs_in2)] with the same leading shape as ``features_1``
    :return: tensor [..., channel]
    """
    *size, n = features_1.size()
    features_1 = features_1.reshape(-1, n)
    assert n == rs.dim(self.Rs_in1), f"{n} is not {rs.dim(self.Rs_in1)}"
    *size2, n = features_2.size()
    features_2 = features_2.reshape(-1, n)
    # validate the second operand too, otherwise a mismatch only surfaces as a cryptic einsum error
    assert n == rs.dim(self.Rs_in2), f"{n} is not {rs.dim(self.Rs_in2)}"
    assert size == size2

    T = get_sparse_buffer(self, 'T')  # [out, in1 * in2]
    kernel = (T.t() @ self.kernel().T).T.reshape(
        rs.dim(self.Rs_out), rs.dim(self.Rs_in1), rs.dim(self.Rs_in2))  # [out, in1, in2]
    features = torch.einsum('kij,zi,zj->zk', kernel, features_1, features_2)
    return features.reshape(*size, -1)
def test_tensor_product_symmetry():
    """Putting the selection rule first or second gives transposed but equal products."""
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (2, 1), (5, 2)]
        Rs_out = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]
        left = rs.TensorProduct(Rs_in, o3.selection_rule, Rs_out)
        right = rs.TensorProduct(o3.selection_rule, Rs_in, Rs_out)
        assert left.Rs_in2 == right.Rs_in1

        features = torch.randn(rs.dim(Rs_in), rs.dim(left.Rs_in2))
        out_left = left(features)
        out_right = right(features.T)
        assert (out_left - out_right).abs().max() < 1e-10
def forward(self, r, r_eps=0, **_kwargs):
    """Evaluate the kernel at relative positions ``r``.

    :param r: tensor [..., 3]
    :param r_eps: points with norm <= r_eps use ``self.linear()`` (the r = 0 kernel)
    :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    *size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)

    radii = r.norm(2, dim=1)  # [batch]
    nonzero = radii > r_eps
    zero = radii <= r_eps

    # (1) Case r > 0

    # precompute all needed spherical harmonics
    Y = rsh.spherical_harmonics_xyz(self.Ls, r[nonzero])  # [batch, l_filter * m_filter]

    # Normalize the spherical harmonics
    if self.normalization == 'component':
        Y.mul_(math.sqrt(4 * math.pi))
    if self.normalization == 'norm':
        # build on Y's dtype/device so the in-place multiply also works off the default device
        diag = math.sqrt(4 * math.pi) * torch.cat([
            Y.new_ones(2 * l + 1) / math.sqrt(2 * l + 1)
            for _, l, _ in self.Rs_f
        ])
        Y.mul_(diag)

    # use the radial model to fix all the degrees of freedom
    # note: for the normalization we assume that the variance of R[i] is one
    R = self.R(radii[nonzero])  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

    RY = rsh.mul_radial_angular(self.Rs_f, R, Y)
    if Y.shape[0] == 0:
        # same dtype/device as `kernel` below; torch.zeros would break the masked assignment
        kernel1 = r.new_zeros(0, rs.dim(self.Rs_out), rs.dim(self.Rs_in))
    else:
        kernel1 = self.tp.right(RY)

    # (2) Case r = 0
    kernel2 = self.linear()

    kernel = r.new_zeros(len(r), *kernel2.shape)
    kernel[nonzero] = kernel1
    kernel[zero] = kernel2

    return kernel.reshape(*size, *kernel2.shape)
def rotation_gated_block(self, K):
    """Test rotation equivariance on GatedBlock and dependencies."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        K = partial(K, RadialModel=ConstantRadialModel)

        act = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(K(Rs_in, act.Rs_in))

        abc = torch.randn(3)
        rot_geo = o3.rot(*abc)
        D_in = rs.rep(Rs_in, *abc)
        D_out = rs.rep(Rs_out, *abc)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        def apply(matrix, t):
            return torch.einsum("ij,zaj->zai", matrix, t)

        # rotate after the network vs. rotate the inputs before it
        after = apply(D_out, act(conv(fea, geo)))
        before = act(conv(apply(D_in, fea), apply(rot_geo, geo)))
        self.assertLess((after - before).norm(), 10e-5 * after.norm())
def parity_rotation_gated_block_parity(self, K):
    """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

        K = partial(K, RadialModel=ConstantRadialModel)

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]

        n = 3 * mul
        gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

        act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
        conv = Convolution(K(Rs_in, act.Rs_in))

        abc = torch.randn(3)
        rot_geo = -o3.rot(*abc)  # improper rotation: rotation composed with inversion
        D_in = rs.rep(Rs_in, *abc, 1)
        D_out = rs.rep(act.Rs_out, *abc, 1)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        def apply(matrix, t):
            return torch.einsum("ij,zaj->zai", matrix, t)

        after = apply(D_out, act(conv(fea, geo)))
        before = act(conv(apply(D_in, fea), apply(rot_geo, geo)))
        self.assertLess((after - before).norm(), 10e-5 * after.norm())
def forward(self, features):
    """Norm activation.

    For each (mul, l) block, scalars (l = 0) become ``activation(x + bias)``. Higher-l
    vectors are scaled by ``activation(norm + bias)``, computed per multiplicity channel.

    :param features: tensor [..., channel]
    :return: tensor [..., channel]
    """
    *size, d = features.shape
    assert d == rs.dim(self.Rs)

    norms = self.norm(features)  # [..., l*mul]

    pieces = []
    feat_pos = 0
    chan_pos = 0
    for mul, l, _ in self.Rs:
        dim_l = 2 * l + 1
        x = features.narrow(-1, feat_pos, mul * dim_l).reshape(*size, mul, dim_l)  # [..., u, i]
        norm = norms.narrow(-1, chan_pos, mul).reshape(*size, mul, 1)  # [..., u, 1]
        bias = self.bias[chan_pos: chan_pos + mul].reshape(mul, 1)  # [u, 1]
        feat_pos += mul * dim_l
        chan_pos += mul

        y = self.activation(x + bias) if l == 0 else self.activation(norm + bias) * x
        pieces.append(y.reshape(*size, mul * dim_l))
    return torch.cat(pieces, dim=-1)
def __init__(self, tensor, Rs):
    """Pair a tensor with its representation list.

    :param tensor: tensor whose last axis has size ``rs.dim(Rs)``
    :param Rs: representation list, normalized with ``rs.convention``
    :raises ValueError: if the last axis size does not match ``rs.dim(Rs)``
    """
    Rs = rs.convention(Rs)
    last = tensor.shape[-1]
    if last != rs.dim(Rs):
        raise ValueError(
            "Last tensor dimension and Rs do not have same dimension.")
    self.tensor, self.Rs = tensor, Rs
def plot_data_on_grid(box_length, radial, Rs, sh=o3.spherical_harmonics_xyz, n=30):
    """Evaluate radial x spherical-harmonic functions on a cubic grid.

    :param box_length: edge length of the cube (centred on the origin)
    :param radial: callable mapping radii to a tensor [r_values, rs.mul_dim(Rs)]
    :param Rs: list of (mul, L)
    :param sh: spherical-harmonics function, called as ``sh(Ls, r)``
    :param n: number of grid points per axis
    :return: ``(r, all_f)``: grid coordinates [n**3, 3] and function values [n**3, rs.dim(Rs)]
    """
    # distinct L in a fixed (sorted) order; each owns one contiguous 2L+1 slice of Ys
    Ls = sorted({L for mul, L in Rs})
    L_to_index = {}
    start = 0
    for L in Ls:
        L_to_index[L] = [start, start + 2 * L + 1]
        start += 2 * L + 1

    r = np.mgrid[-1:1:n * 1j, -1:1:n * 1j, -1:1:n * 1j].reshape(3, -1)
    r = r.transpose(1, 0)
    r *= box_length / 2.
    r = torch.from_numpy(r)

    Ys = sh(Ls, r)
    R = radial(r.norm(2, -1)).detach()  # [r_values, n_r_filters]
    assert R.shape[-1] == rs.mul_dim(Rs)

    # repeat each L's harmonics once per multiplicity, following the order of Rs
    Ys_indices = []
    for mul, L in Rs:
        Ys_indices += list(range(*L_to_index[L])) * mul

    R_helper = rs.map_mul_to_Rs(Rs)

    full_Ys = Ys[Ys_indices]  # [rs.dim(Rs), values]
    full_Ys = full_Ys.reshape(full_Ys.shape[0], -1)

    all_f = torch.einsum('xn,dn,dx->xd', R, R_helper, full_Ys)
    return r, all_f
def main():
    """Train one SE3Net per family of layer representations on the tetris dataset.

    :return: list holding the first value returned by ``train`` for each family
    """
    def to_Rs(muls):
        # the multiplicity at position l becomes the entry (mul, l)
        return [(mul, l) for l, mul in enumerate(muls)]

    representations1 = [to_Rs(muls) for muls in [(1, ), (3, 4, 0, 0), (8, 8, 0, 0), (8, 6, 0, 0), (64, )]]
    representations2 = [to_Rs(muls) for muls in [(1, ), (2, 3, 2, 0), (6, 5, 5, 0), (6, 4, 4, 0), (64, )]]
    representations3 = [to_Rs(muls) for muls in [(1, ), (2, 2, 2, 1), (4, 4, 4, 4), (6, 4, 4, 0), (64, )]]
    # scalar-only family: dim(Rs) scalars for each layer of representations3
    representations0 = [[(dim(Rs), 0)] for Rs in representations3]

    tetris, labels = get_dataset()

    data = []
    for reps in [representations0, representations1, representations2, representations3]:
        f = SE3Net(len(tetris), reps)
        training, _ = train(tetris, labels, f)
        data.append(training)
    return data
def check_rotation(batch: int = 10, n_atoms: int = 25):
    """Return whether GatedBlock after Convolution is rotation equivariant on random data."""
    # Setup the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    act = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=absolute)
    conv = Convolution(K, Rs_in, act.Rs_in)

    # A random rotation (Euler angles) acts on the geometry and on both feature spaces.
    abc = torch.randn(3)
    rot_geo = o3.rot(*abc)
    D_in = rs.rep(Rs_in, *abc)
    D_out = rs.rep(Rs_out, *abc)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))  # transforms with D_in
    geo = torch.randn(batch, n_atoms, 3)  # transforms with rot_geo

    def model(f, g):
        return act(conv(f, g))

    rotated_after = torch.einsum("ij,zkj->zki", D_out, model(feat, geo))
    rotated_before = model(feat @ D_in.t(), geo @ rot_geo.t())
    return (rotated_after - rotated_before).norm() < 10e-5 * rotated_after.norm()
def test1(self):
    """Linear on standard-normal inputs: the output histogram should be close to N(0, 1) (KL < 0.1)."""
    # switch to float64 for this test only; restoring it keeps later tests unaffected
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in = [(5, 0), (20, 1), (15, 0), (20, 2)]
        Rs_out = [(5, 0), (10, 1), (10, 2), (5, 0)]

        with torch.no_grad():
            lin = Linear(Rs_in, Rs_out)
            features = torch.randn(10000, rs.dim(Rs_in))
            features = lin(features)

        bins, left, right = 100, -4, 4
        bin_width = (right - left) / (bins - 1)
        x = torch.linspace(left, right, bins)
        p = torch.histc(features, bins, left, right) / features.numel() / bin_width
        q = x.pow(2).div(-2).exp().div(math.sqrt(2 * math.pi))  # Normal law

        Dkl = ((p + 1e-100) / q).log().mul(p).sum()  # Kullback-Leibler divergence of P || Q
        self.assertLess(Dkl, 0.1)
    finally:
        torch.set_default_dtype(previous_dtype)
def kernel_conv_fn_forward(F, Y, R, norm_coef, Rs_in, Rs_out, selection_rule, set_of_l_filters):
    """
    Apply the equivariant kernel to features ``F`` without materializing the kernel.

    For every point ``a``, the result sums over points ``b`` the kernel block (built from
    ``Y``, ``R`` and ``o3.wigner_3j``) applied to ``F`` at ``b``, accumulated over all
    input blocks allowed by ``selection_rule``.

    :param F: tensor [batch, b, l_in * mul_in * m_in]
    :param Y: tensor [batch, a, b, l_filter * m_filter] (the code unpacks ``batch, a, b, _ = Y.shape``
        and slices the last axis). Assumed to hold one ``2l+1`` slice per ``l`` in
        ``set_of_l_filters``, in increasing ``l``.
    :param R: tensor [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in], indexed by the block indices (i, j)
    :param Rs_in: list of (mul, l, p) input representations
    :param Rs_out: list of (mul, l, p) output representations
    :param selection_rule: callable (l_in, p_in, l_out, p_out) -> list of allowed l_filter
    :param set_of_l_filters: the l values whose harmonics are stacked in ``Y``
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    batch, a, b, _ = Y.shape
    n_out = rs.dim(Rs_out)

    kernel_conv = Y.new_zeros(batch, a, n_out)

    # note: for the normalization we assume that the variance of R[i] is one
    # begin_R walks through R's last axis; it only advances for pairs allowed by selection_rule
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = selection_rule(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, :, :, begin_R:begin_R + n].reshape(
                batch, a, b, mul_out, mul_in, -1)  # [batch, a, b, mul_out, mul_in, l_filter]
            begin_R += n

            K = 0
            for k, l_filter in enumerate(l_filters):
                # offset of this l_filter's 2l+1 harmonics inside Y's last axis
                offset = sum(2 * l + 1 for l in set_of_l_filters if l < l_filter)
                sub_Y = Y[..., offset:offset + 2 * l_filter + 1]  # [batch, a, b, m]

                C = o3.wigner_3j(l_out, l_in, l_filter, cached=True, like=kernel_conv)  # [m_out, m_in, m]

                # contracts over b (neighbours), k (m), v (mul_in) and j (m_in)
                K += norm_coef[i, j] * torch.einsum(
                    "ijk,zabk,zabuv,zbvj->zaui",
                    C, sub_Y, sub_R[..., k],
                    F[..., s_in].reshape(batch, b, mul_in, -1)
                )  # [batch, a, mul_out, m_out]

            # defensive: K would still be the int 0 only if no l_filter contributed
            if not isinstance(K, int):
                # every input block j accumulates into the same output slice
                kernel_conv[:, :, s_out] += K.reshape(batch, a, -1)
    return kernel_conv
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Return whether GatedBlockParity after Convolution is equivariant under an improper rotation."""
    # Setup the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0, +1)]
    act = GatedBlockParity(
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)],
    )
    conv = Convolution(K, Rs_in, act.Rs_in)
    Rs_out = act.Rs_out

    # Geometry, input and output features all transform under rotation + parity.
    abc = torch.randn(3)
    rot_geo = -o3.rot(*abc)  # negative: the geometry has odd parity (improper rotation)
    D_in = rs.rep(Rs_in, *abc, parity=1)
    D_out = rs.rep(Rs_out, *abc, parity=1)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    geo = torch.randn(batch, n_atoms, 3)

    def model(f, g):
        return act(conv(f, g))

    rotated_after = torch.einsum("ij,zkj->zki", D_out, model(feat, geo))
    rotated_before = model(feat @ D_in.t(), geo @ rot_geo.t())
    return (rotated_after - rotated_before).norm() < 10e-5 * rotated_after.norm()
def forward(self, features):
    '''
    Quadratic map: contract the outer product of ``features`` with itself against the kernel.

    :param features: [..., channels]
    '''
    *size, n = features.size()
    features = features.reshape(-1, n)
    assert n == rs.dim(self.Rs_in)

    if self.linear:
        # prepend a constant 1 so the outer product also contains the linear and constant terms
        ones = features.new_ones(features.shape[0], 1)
        features = torch.cat([ones, features], dim=1)
        n += 1

    T = get_sparse_buffer(self, 'T')  # [out, in1 * in2]
    dense = (T.t() @ self.kernel().T).T
    kernel = dense.reshape(rs.dim(self.Rs_out), n, n)  # [out, in1, in2]

    outer = torch.einsum('zi,zj->zij', features, features)
    result = torch.einsum('kij,zij->zk', kernel, outer)
    return result.reshape(*size, -1)
def test_tensor_product(self):
    """TensorProduct agrees with the dense tensor returned by rs.tensor_product."""
    # switch to float64 for this test only; restoring it keeps later tests unaffected
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        Rs_out, m = rs.tensor_product(Rs_1, Rs_2, o3.selection_rule)
        mul = TensorProduct(Rs_1, Rs_2)

        x1 = torch.randn(1, rs.dim(Rs_1))
        x2 = torch.randn(1, rs.dim(Rs_2))

        y1 = mul(x1, x2)
        y2 = torch.einsum('kij,zi,zj->zk', m, x1, x2)

        self.assertEqual(rs.dim(Rs_out), y1.shape[1])
        self.assertLess((y1 - y2).abs().max(), 1e-7 * y1.abs().max())
    finally:
        torch.set_default_dtype(previous_dtype)
def forward(self, features_1, features_2, weights):
    """Tensor product of two feature batches with externally supplied weights.

    :param weights: tensor reshaped to [-1, self.nweight] before use
    :return: tensor [..., channel]
    """
    def flatten(features, Rs):
        # collapse leading axes; the channel count must match Rs
        *lead, n = features.size()
        flat = features.reshape(-1, n)
        assert n == rs.dim(Rs), f"{n} is not {rs.dim(Rs)}"
        return lead, flat

    size, features_1 = flatten(features_1, self.Rs_in1)
    size2, features_2 = flatten(features_2, self.Rs_in2)
    assert size == size2

    flat_weights = weights.reshape(-1, self.nweight)
    wigners = [getattr(self, name) for name in self.wigners_names]
    out = self.main(*wigners, features_1, features_2, flat_weights)
    return out.reshape(*size, -1)
def test(Rs, ac):
    # `self` comes from the enclosing test method (closure)
    x = torch.randn(99, rs.dim(Rs))
    a, b = torch.rand(2)
    c = 1

    def rotate(t, Rs_t, alpha, beta, gamma):
        return t @ rs.rep(Rs_t, alpha, beta, gamma).T

    y1 = rotate(ac(x, dim=-1), ac.Rs_out, a, b, c)
    y2 = ac(rotate(x, Rs, a, b, c), dim=-1)
    # (-c, -b, -a) presumably gives the inverse rotation: a deliberately wrong reference
    y3 = ac(rotate(x, Rs, -c, -b, -a), dim=-1)
    self.assertLess((y1 - y2).norm(), (y1 - y3).norm())
def forward(self, x):
    """
    :param x: [batch, x, y, z, channel_in]
    :return: [batch, x, y, z, channel_out]
    """
    for conv, act, pool in self.layers:
        h = conv(x)
        h = h.reshape(*h.shape[:-1], self.mul, rs.dim(act.Rs_in))  # put multiplicity into batch
        h = act(h)
        h = h.reshape(*h.shape[:-2], self.mul * rs.dim(act.Rs_out))  # put back into representation
        x = pool(h)
    return self.tail(x)