def test_single_out():
    """Output irreps not reached by any instruction stay zero and leave the rest untouched."""
    instr = [(0, 0, 0, "uvw", True, 1.0)]
    narrow = TensorProduct("5x0e", "5x0e", "5x0e", instr)
    wide = TensorProduct("5x0e", "5x0e", "5x0e + 3x0o", instr)

    # Give both products identical weights so their shared 5x0e block is comparable.
    with torch.no_grad():
        wide.weight[:] = narrow.weight

    a = torch.randn(3, 5)
    b = torch.randn(3, 5)
    y_narrow = narrow(a, b)
    y_wide = wide(a, b)

    assert y_narrow.shape == (3, 5)
    assert y_wide.shape == (3, 8)
    # The first 5 columns match; the trailing 3x0o columns receive no instruction.
    assert torch.allclose(y_narrow, y_wide[:, :5])
    assert torch.all(y_wide[:, 5:] == 0)
def test_empty():
    """Instructions that only feed a zero-multiplicity output give an all-zero result."""
    instructions = [
        (0, 0, 0, "uvw", True),
        (1, 1, 0, "uvw", True),
    ]
    tp = TensorProduct("0x0e + 1o + 2e", "0e + 1o + 2e", "0x0e + 1o", instructions)

    batch = 4
    lhs = tp.irreps_in1.randn(batch, -1)
    rhs = tp.irreps_in2.randn(batch, -1)
    result = tp(lhs, rhs)

    assert result.shape == (batch, tp.irreps_out.dim)
    # Both instructions target output 0 (0x0e), so no instruction leads to the 1o output.
    assert torch.all(result == 0.0)
    # Smoke check: partial evaluation on the second operand must not raise.
    tp.right(rhs)
def make_tp(l1, p1, l2, p2, lo, po, mode, weight, mul: int = 25, path_weights: bool = True, **kwargs):
    """Build a two-block ``TensorProduct`` test fixture, or return ``None`` if it is rejected.

    Each of the three irreps lists holds a ``mul``-multiplicity block followed by a
    19-multiplicity block of the same irrep. Two ``mode`` instructions connect
    matching blocks. Two extra ``'uvw'`` instructions feed the second output block
    with path weights 0.5 and 0.2 (1.0 each when ``path_weights`` is false).
    Extra keyword arguments are forwarded to ``TensorProduct``. An
    ``AssertionError`` raised while constructing it yields ``None``.
    """
    def out_mul(m):
        # In 'uvuv' mode the output multiplicity is the square of the input one.
        return m**2 if mode == "uvuv" else m

    in1 = [(mul, (l1, p1)), (19, (l1, p1))]
    in2 = [(mul, (l2, p2)), (19, (l2, p2))]
    out = [(out_mul(mul), (lo, po)), (out_mul(19), (lo, po))]
    w_a, w_b = (0.5, 0.2) if path_weights else (1.0, 1.0)
    instructions = [
        (0, 0, 0, mode, weight),
        (1, 1, 1, mode, weight),
        (0, 0, 1, 'uvw', True, w_a),
        (0, 1, 1, 'uvw', True, w_b),
    ]
    try:
        return TensorProduct(in1, in2, out, instructions, **kwargs)
    except AssertionError:
        return None
def __init__(self, irreps_out, num_z, lmax) -> None:
    """Create the edge-attribute product and two fully connected tensor products.

    Args:
        irreps_out: irreps of the final output (anything ``o3.Irreps`` accepts).
        num_z: the edge-type one-hot has ``num_z**2`` scalar channels
            (presumably ``num_z`` is the number of species -- TODO confirm).
        lmax: maximum degree of the spherical harmonics.
    """
    super().__init__()
    self.num_z = num_z
    self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)
    sh = self.irreps_sh
    n_types = num_z**2

    # to multiply the edge type one-hot with the spherical harmonics to get the edge attributes
    # Instruction (0, deg, deg): the one-hot block times the degree-`deg` harmonic
    # goes to output block `deg`, which has the same irrep.
    self.mul = TensorProduct(
        [(n_types, "0e")],
        sh,
        [(n_types, ir) for _, ir in sh],
        [(0, deg, deg, "uvu", False) for deg in range(lmax + 1)],
    )

    irreps_attr = self.mul.irreps_out
    irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
    final_irreps = o3.Irreps(irreps_out)

    # Two stacked fully connected products, each conditioned on the edge attributes.
    self.tp1 = FullyConnectedTensorProduct(irreps_in1=sh, irreps_in2=irreps_attr, irreps_out=irreps_mid)
    self.tp2 = FullyConnectedTensorProduct(irreps_in1=irreps_mid, irreps_in2=irreps_attr, irreps_out=final_irreps)
def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output, fc_neurons, num_neighbors) -> None:
    """Build the sub-modules of the convolution.

    Args:
        irreps_node_input: irreps of the input node features.
        irreps_node_attr: irreps of the node attributes.
        irreps_edge_attr: irreps of the edge attributes.
        irreps_node_output: irreps of the output node features.
        fc_neurons: sequence of hidden widths for ``self.fc``; the tensor-product
            weight count is appended as its final width.
        num_neighbors: stored as ``self.num_neighbors`` (its use is not visible
            here -- presumably a normalization constant in ``forward``).

    Raises:
        AssertionError: if no input x edge path yields an output irrep (or ``0e``),
            or if ``self.alpha`` cannot produce a scalar.
    """
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    # Keep every path whose output irrep appears in the node output, plus all 0e
    # paths (presumably so that self.alpha below has scalars to read).
    scalar = o3.Irrep(0, 1)
    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
            for ir_out in ir_in * ir_edge:
                if ir_out in self.irreps_node_output or ir_out == scalar:
                    k = len(irreps_mid)
                    irreps_mid.append((mul, ir_out))
                    instructions.append((i, j, k, 'uvu', True))

    # Sort the intermediate irreps and remap each instruction's output index through p.
    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()
    assert irreps_mid.dim > 0, f"irreps_node_input={self.irreps_node_input} times irreps_edge_attr={self.irreps_edge_attr} produces nothing in irreps_node_output={self.irreps_node_output}"
    instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

    # Weights are not stored in tp; self.fc's last layer has tp.weight_numel outputs,
    # presumably to supply them at call time.
    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        instructions,
        internal_weights=False,
        shared_weights=False,
    )
    # list(...) so that fc_neurons may be any sequence (e.g. a tuple), not only a list.
    self.fc = FullyConnectedNet(list(fc_neurons) + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)

    # inspired by https://arxiv.org/pdf/2002.10444.pdf
    # alpha starts at zero: its weights are zeroed right after construction.
    self.alpha = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")
    with torch.no_grad():
        self.alpha.weight.zero_()
    assert self.alpha.output_mask[0] == 1.0, f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars"
def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output, fc_neurons, num_neighbors) -> None:
    """Build the sub-modules of this convolution variant (scalar head ``lin3``).

    Args:
        irreps_node_input: irreps of the input node features.
        irreps_node_attr: irreps of the node attributes.
        irreps_edge_attr: irreps of the edge attributes.
        irreps_node_output: irreps of the output node features.
        fc_neurons: list of hidden widths for ``self.fc``; the tensor-product
            weight count is appended as its final width.
        num_neighbors: stored as ``self.num_neighbors`` (use not visible here).
    """
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    # Every (input block, edge block, product irrep) triple whose product irrep is
    # either present in the node output or is 0e becomes one 'uvu' path.
    keep_scalar = o3.Irrep(0, 1)
    paths = [
        (i, j, mul, ir_out)
        for i, (mul, ir_in) in enumerate(self.irreps_node_input)
        for j, (_, ir_edge) in enumerate(self.irreps_edge_attr)
        for ir_out in ir_in * ir_edge
        if ir_out in self.irreps_node_output or ir_out == keep_scalar
    ]

    # Path k owns intermediate block k; after sorting, its block lives at perm[k].
    unsorted_mid = o3.Irreps([(mul, ir_out) for _, _, mul, ir_out in paths])
    irreps_mid, perm, _ = unsorted_mid.sort()
    instructions = [(i, j, perm[k], 'uvu', True) for k, (i, j, _, _) in enumerate(paths)]

    # tp keeps no weights of its own; self.fc ends in tp.weight_numel outputs.
    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        instructions,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)
    self.lin3 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")
def test_specialized_code(normalization, mode, weighted, float_tolerance):
    """The specialized and generic TensorProduct code paths must agree numerically."""
    irreps_in1 = Irreps('4x0e + 4x1e + 4x2e')
    irreps_in2 = Irreps('5x0e + 5x1e + 5x2e')
    irreps_out = Irreps('6x0e + 6x1e + 6x2e')

    # Adapt the operands to the multiplicity pattern each mode expects.
    if mode == 'uvu':
        irreps_out = irreps_in1
    elif mode == 'uvv':
        irreps_out = irreps_in2
    elif mode in ('uuu', 'uuw'):
        irreps_in2 = irreps_in1
        if mode == 'uuu':
            irreps_out = irreps_in1
        elif not weighted:
            # When unweighted, uuw is a plain sum over u and requires an output mul of 1
            irreps_out = Irreps([(1, ir) for _, ir in irreps_out])

    paths = [
        (0, 0, 0), (0, 1, 1), (1, 0, 1),
        (1, 1, 0), (1, 1, 1), (0, 2, 2),
        (2, 0, 2), (2, 2, 0), (2, 1, 1),
    ]
    ins = [(a, b, c, mode, weighted, 1.0) for a, b, c in paths]

    generic = TensorProduct(irreps_in1, irreps_in2, irreps_out, ins, normalization=normalization, _specialized_code=False)
    specialized = TensorProduct(irreps_in1, irreps_in2, irreps_out, ins, normalization=normalization, _specialized_code=True)
    with torch.no_grad():
        specialized.weight[:] = generic.weight

    x = irreps_in1.randn(3, -1)
    y = irreps_in2.randn(3, -1)
    assert (generic(x, y) - specialized(x, y)).abs().max() < float_tolerance
    assert (generic.right(y) - specialized.right(y)).abs().max() < float_tolerance
def __init__(self, irreps_in, irreps_out, irreps_sh, dim_key):
    """Build the sub-modules: skip map, two linear-like maps and a key-indexed tensor product.

    Args:
        irreps_in: input irreps (an ``o3.Irreps`` or anything it accepts, e.g. a string).
        irreps_out: output irreps (same accepted forms).
        irreps_sh: irreps of the spherical harmonics fed as the second
            tensor-product operand (same accepted forms).
        dim_key: number of rows of ``self.tp_weight``; each row holds one full
            set of ``self.tp`` weights.
    """
    super().__init__()
    # Wrapping in o3.Irreps lets callers pass strings too; an Irreps is copied unchanged.
    self.irreps_in = o3.Irreps(irreps_in).simplify()
    self.irreps_out = o3.Irreps(irreps_out).simplify()
    self.irreps_sh = o3.Irreps(irreps_sh).simplify()

    # The "linear" maps are FullyConnectedTensorProducts against a fixed 5x0e
    # second operand (an o3.Linear variant was previously left commented out here).
    self.si = FullyConnectedTensorProduct(self.irreps_in, o3.Irreps("5x0e"), self.irreps_out)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_in, o3.Irreps("5x0e"), self.irreps_in)

    # (l, p) pairs present in the output; computed once rather than per candidate path.
    allowed = {(l, p) for _, (l, p) in self.irreps_out}

    instr = []
    irreps = []
    for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
        for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
            p_out = p_1 * p_2
            for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                if (l_out, p_out) not in allowed:
                    continue
                # Nested (mul, (l, p)) is the form o3.Irreps accepts; a flat
                # (mul, l, p) triple is rejected by current e3nn releases.
                r = (mul_1, (l_out, p_out))
                # Reuse an existing intermediate block with the same (mul, l, p).
                if r in irreps:
                    i_out = irreps.index(r)
                else:
                    i_out = len(irreps)
                    irreps.append(r)
                instr.append((i_1, i_2, i_out, 'uvu', True))
    irreps = o3.Irreps(irreps)

    # tp keeps no weights of its own; they come from the dim_key x weight_numel table.
    self.tp = TensorProduct(self.irreps_in, self.irreps_sh, irreps, instr, internal_weights=False, shared_weights=False)
    self.tp_weight = torch.nn.Parameter(torch.randn(dim_key, self.tp.weight_numel))

    self.lin2 = FullyConnectedTensorProduct(irreps, o3.Irreps("5x0e"), self.irreps_out)