def test_tensor_product_symmetry():
    """Swapping the operand order of a TensorProduct gives the same result.

    One product takes ``Rs_in`` as its first operand, the other takes it as
    its second operand; feeding the transposed input to the second product
    must reproduce the output of the first one.
    """
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (2, 1), (5, 2)]
        Rs_out = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        tp_left = rs.TensorProduct(Rs_in, o3.selection_rule, Rs_out)
        tp_right = rs.TensorProduct(o3.selection_rule, Rs_in, Rs_out)
        # The filter representation must be the same on both sides.
        assert tp_left.Rs_in2 == tp_right.Rs_in1

        features = torch.randn(rs.dim(Rs_in), rs.dim(tp_left.Rs_in2))
        out_left = tp_left(features)
        out_right = tp_right(features.T)
        assert (out_left - out_right).abs().max() < 1e-10
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, normalization='component'):
    """
    Kernel evaluated at a fixed set of positions ``r``.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3] positions; the leading shape is stored in ``self.size``
    :param float r_eps: distance considered as zero. Spherical harmonics are only
        computed for positions with norm > r_eps; if any position has norm <= r_eps
        a ``KernelLinear`` is created in ``self.linear`` (otherwise it is None)
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()

    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
    self.r_eps = r_eps

    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2

    Y = rsh.spherical_harmonics_xyz(
        [(1, l, p) for _, l, p in self.Rs_f],
        r[self.radii > self.r_eps],
    )  # [batch, l_filter * m_filter]

    # Normalize the spherical harmonics
    if normalization == 'component':
        Y.mul_(math.sqrt(4 * math.pi))
    if normalization == 'norm':
        # Build the per-component scale on Y's own device and dtype: a tensor created
        # on the default device would make the in-place multiply fail when Y lives
        # elsewhere (e.g. when r is on a GPU).
        diag = math.sqrt(4 * math.pi) * torch.cat([
            torch.ones(2 * l + 1, dtype=Y.dtype, device=Y.device) / math.sqrt(2 * l + 1)
            for _, l, _ in self.Rs_f
        ])
        Y.mul_(diag)
    self.register_buffer('Y', Y)

    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    # Positions at (numerically) zero distance cannot use spherical harmonics;
    # they are handled by a separate linear kernel.
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def __init__(self, Rs_in, Rs_out, RadialModel, selection_rule=o3.selection_rule_in_out_sh, normalization='component', allow_unused_inputs=False, allow_zero_outputs=False):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'
    :param bool allow_unused_inputs: when True, skip ``self.check_input(selection_rule)``
    :param bool allow_zero_outputs: when True, skip ``self.check_output(selection_rule)``

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()

    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    if not allow_unused_inputs:
        self.check_input(selection_rule)
    if not allow_zero_outputs:
        self.check_output(selection_rule)

    self.normalization = normalization

    # Pass the normalized self.Rs_out, as is done for self.Rs_in. If the caller
    # handed in a one-shot iterable, rs.convention may already have consumed it.
    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2
    self.Ls = [l for _, l, _ in self.Rs_f]
    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    self.linear = KernelLinear(self.Rs_in, self.Rs_out)
def test_tensor_product_to_dense():
    """``to_dense`` has shape (dim(Rs_out), dim(Rs_1), dim(Rs_2))."""
    with o3.torch_default_dtype(torch.float64):
        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]
        tp = rs.TensorProduct(Rs_1, Rs_2, o3.selection_rule)

        dense = tp.to_dense()
        expected_shape = (rs.dim(tp.Rs_out), rs.dim(Rs_1), rs.dim(Rs_2))
        assert dense.shape == expected_shape
def __init__(self, Rs_in, Rs_out):
    """
    Build a tensor product with a lmax=0 selection rule and a random weight vector.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    # lmax=0 presumably limits the filter to l = 0 components -- confirm against
    # o3.selection_rule_in_out_sh.
    rule = partial(o3.selection_rule_in_out_sh, lmax=0)
    self.tp = rs.TensorProduct(Rs_in, rule, Rs_out, sorted=False)

    # One trainable weight per component of the filter representation.
    n_weights = rs.dim(self.tp.Rs_in2)
    self.weight = torch.nn.Parameter(torch.randn(n_weights))
def test_tensor_product_in_in_normalization_norm(Rs_in1, Rs_in2):
    """With 'norm' normalization, mean output norms stay within a factor 10 of 1."""
    with o3.torch_default_dtype(torch.float64):
        tp = rs.TensorProduct(Rs_in1, Rs_in2, o3.selection_rule, normalization='norm')

        x1 = rs.randn(10, Rs_in1, normalization='norm')
        x2 = rs.randn(10, Rs_in2, normalization='norm')

        norm = Norm(tp.Rs_out, normalization='norm')
        mean_norms = norm(tp(x1, x2)).mean(0)

        # |log10(v)| < 1  <=>  0.1 < v < 10 for every mean norm.
        assert (mean_norms.log10().abs() < 1).all()
def initialize_edges(x, Rs_in, pos, edge_index_dict, lmax, self_edge=1., symmetric_edges=False):
    """Initialize edge features of DataEdgeNeighbors using node features and SphericalTensor.

    For each edge ``(target, source)`` the two endpoint feature vectors are
    contracted with the change-of-basis ``Q`` from ``rs.reduce_tensor``. The
    result is then tensor-multiplied with the spherical-harmonic signal of the
    relative position ``pos[source] - pos[target]``.

    Args:
        x (torch.tensor shape [N, rs.dim(Rs_in)]): Node features.
        Rs_in (rs.TY_RS_STRICT): Representation list of input.
        pos (torch.tensor shape [N, 3]): Cartesian coordinates of nodes.
        edge_index_dict (dict): Maps ``(target, source)`` node-index pairs to a
            sortable value (presumably the edge index). Edges are emitted in
            increasing order of that value.
        lmax (int > 0): Maximum L to use for SphericalTensor projection of radial distance vectors
        self_edge (float, optional): L=0 feature for self edges, i.e. edges whose
            relative position vector is zero. Defaults to 1.
        symmetric_edges (bool, optional): Constrain edge features to be symmetric in node index. Defaults to False

    Returns:
        edge_x: Edge features, one row per entry of ``edge_index_dict``
            (a tensor with 0 rows if the dict is empty).
        Rs_edge (rs.TY_RS_STRICT): Representation list of edge features.
    """
    from e3nn.tensor import SphericalTensor

    if symmetric_edges:
        Rs, Q = rs.reduce_tensor('ij=ji', i=Rs_in)
    else:
        Rs, Q = rs.reduce_tensor('ij', i=Rs_in, j=Rs_in)
    # Shape Q so it can be contracted with two node feature vectors (einsum below).
    Q = Q.reshape(-1, rs.dim(Rs_in), rs.dim(Rs_in))

    Rs_sph = [(1, l, (-1)**l) for l in range(lmax + 1)]
    tp_kernel = rs.TensorProduct(Rs, Rs_sph, o3.selection_rule)

    # Stable sort on the dict values; equivalent to the former unzip/rezip of items().
    sorted_edges = sorted(edge_index_dict.items(), key=lambda item: item[1])
    if not sorted_edges:
        # No edges: return an empty feature matrix instead of crashing on unpack/stack.
        empty = torch.zeros(0, rs.dim(tp_kernel.Rs_out), dtype=x.dtype, device=x.device)
        return empty, tp_kernel.Rs_out

    edge_x = []
    for (target, source), _ in sorted_edges:
        Ia = x[target]
        Ib = x[source]
        vector = (pos[source] - pos[target]).reshape(-1, 3)
        # zeros_like matches vector's dtype/device; torch.allclose rejects mismatches.
        if torch.allclose(vector, torch.zeros_like(vector)):
            # Self edge: only the L=0 component is set.
            signal = torch.zeros(rs.dim(Rs_sph), device=vector.device)
            signal[0] = self_edge
        else:
            signal = SphericalTensor.from_geometry(vector, lmax=lmax).signal
            if symmetric_edges:
                # Average the projections of +vector and -vector.
                signal += SphericalTensor.from_geometry(-vector, lmax=lmax).signal
                signal *= 0.5
        output = torch.einsum('kij,i,j->k', Q, Ia, Ib)
        output = tp_kernel(output, signal)
        edge_x.append(output)
    edge_x = torch.stack(edge_x, dim=0)
    return edge_x, tp_kernel.Rs_out
def test_tensor_product_equal_TensorProduct():
    """``rs.TensorProduct`` matches the matrix returned by ``rs.tensor_product``."""
    with o3.torch_default_dtype(torch.float64):
        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        Rs_out, mixing = rs.tensor_product(Rs_1, Rs_2, o3.selection_rule, sorted=True)
        module = rs.TensorProduct(Rs_1, Rs_2, o3.selection_rule)

        x1 = rs.randn(1, Rs_1)
        x2 = rs.randn(1, Rs_2)
        y_module = module(x1, x2)

        # Outer product laid out as [i, j, batch], flattened over (i, j), then projected.
        outer = torch.einsum('zi,zj->ijz', x1, x2)
        flat = outer.reshape(rs.dim(Rs_1) * rs.dim(Rs_2), -1)
        y_matrix = (mixing @ flat).T

        assert rs.dim(Rs_out) == y_module.shape[1]
        assert (y_module - y_matrix).abs().max() < 1e-10 * y_module.abs().max()
def test_tensor_product_left_right():
    """All call paths of TensorProduct give the same output as ``mul(x1, x2)``."""
    with o3.torch_default_dtype(torch.float64):
        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]
        mul = rs.TensorProduct(Rs_1, Rs_2, o3.selection_rule)

        x1 = rs.randn(2, Rs_1)
        x2 = rs.randn(2, Rs_2)
        reference = mul(x1, x2)

        def check(candidate):
            assert (reference - candidate).abs().max() < 1e-10 * reference.abs().max()

        # Single-argument call with a precomputed batched outer product.
        check(mul(torch.einsum('zi,zj->zij', x1, x2)))

        # Exercise the code paths selected by the private `_complete` attribute.
        for mode in ('in1', 'in2'):
            mul._complete = mode
            check(mul(x1, x2))
def __matmul__(self, other):
    """Return ``self @ other``: the tensor product of the two signals.

    The result is an ``IrrepTensor`` whose representation list is the
    ``Rs_out`` of ``rs.TensorProduct(self.Rs, other.Rs, o3.selection_rule)``.
    """
    # TODO: better handle mismatch of features indices
    product = rs.TensorProduct(self.Rs, other.Rs, o3.selection_rule)
    combined = product(self.signal, other.signal)
    return IrrepTensor(combined, product.Rs_out)