def __mul__(self, other):
    """Dot product, valid when the Rs of both objects match.

    Both signals are promoted to a common lmax, multiplied componentwise,
    and the products are contracted down to scalar (l=0) channels.
    """
    lmax = max(self.lmax, other.lmax)
    lhs = self.change_lmax(lmax)
    rhs = other.change_lmax(lmax)
    product = lhs.signal * rhs.signal
    # Contract each irrep's components into a single scalar per channel.
    mapping_matrix = rs.map_mul_to_Rs(lhs.Rs)
    scalars = torch.einsum('rm,...r->...m', mapping_matrix, product)
    # Output irreps are all scalars; parity multiplies under the product.
    Rs = [(1, 0, p1 * p2)
          for (_, l1, p1), (_, l2, p2) in zip(lhs.Rs, rhs.Rs)]
    return IrrepTensor(scalars, Rs)
def test_from_irrep_tensor():
    """Check that from_irrep_tensor rejects invalid Rs and converts valid input.

    Fix: the original used bare ``except:`` clauses, which swallow *every*
    exception (TypeError, AttributeError, even KeyboardInterrupt) and would
    make the test pass on unrelated failures. The assertion messages state
    that a ``ValueError`` is expected, so catch exactly that.
    """
    # Rs with multiplicity 2 — expected to be rejected with ValueError.
    irrep = IrrepTensor(torch.randn(6), Rs=[(2, 1, 0)])
    try:
        SphericalTensor.from_irrep_tensor(irrep)
    except ValueError:
        pass  # Expected
    else:
        raise AssertionError("ValueError was not raised.")
    # Same l appearing twice (with both parities) — expected to be rejected.
    irrep = IrrepTensor(torch.randn(6), Rs=[(1, 1, 1), (1, 1, -1)])
    try:
        SphericalTensor.from_irrep_tensor(irrep)
    except ValueError:
        pass  # Expected
    else:
        raise AssertionError("ValueError was not raised.")
    # Valid input: l=0 and l=3 components embed into a lmax=3 signal
    # (16 = sum of 2l+1 for l = 0..3); missing l's stay zero.
    irrep = IrrepTensor(torch.ones(8), Rs=[(1, 0, 0), (1, 3, 0)])
    sph = SphericalTensor.from_irrep_tensor(irrep)
    compare = torch.zeros(16)
    compare[0] = 1.
    compare[-7:] = 1.
    assert torch.allclose(sph.signal, compare)
def __matmul__(self, other):
    """Tensor product of two irrep tensors.

    TODO: better handle mismatch of feature indices.
    """
    tensor_product = rs.TensorProduct(self.Rs, other.Rs, o3.selection_rule)
    combined = tensor_product(self.signal, other.signal)
    return IrrepTensor(combined, tensor_product.Rs_out)
def to_irrep_tensor(self):
    """Convert this tensor into an IrrepTensor via the irrep change of basis."""
    Rs_out, change_of_basis = self.to_irrep_transformation()
    flat = self.tensor.reshape(-1)
    # Apply the change-of-basis matrix to the flattened tensor.
    irrep_values = torch.einsum('ab,b->a', change_of_basis, flat)
    return IrrepTensor(irrep_values, Rs_out)