Example #1
0
def test_single_out():
    """Check that an output irrep no instruction writes to stays zero.

    Two tensor products share the same single ``uvw`` path; the second one
    declares an extra ``3x0o`` output that no instruction targets.  With
    identical weights, the first five output channels must agree and the
    remaining three must be exactly zero.
    """
    batch, width = 3, 5
    path = (0, 0, 0, "uvw", True, 1.0)

    tp1 = TensorProduct("5x0e", "5x0e", "5x0e", [path])
    tp2 = TensorProduct("5x0e", "5x0e", "5x0e + 3x0o", [path])

    # Give both modules the same weights so their outputs are comparable.
    with torch.no_grad():
        tp2.weight.copy_(tp1.weight)

    x1 = torch.randn(batch, width)
    x2 = torch.randn(batch, width)

    out1 = tp1(x1, x2)
    out2 = tp2(x1, x2)

    assert out1.shape == (3, 5)
    assert out2.shape == (3, 8)

    shared, extra = out2[:, :width], out2[:, width:]
    assert torch.allclose(out1, shared)
    assert torch.all(extra == 0)
Example #2
0
def test_empty():
    """Run a tensor product whose irreps contain zero-multiplicity entries.

    Both instructions target output index 0 (``0x0e``), so nothing is ever
    written to the ``1o`` output and the whole result must be zero.  The
    ``right`` method is also called to make sure it runs on this layout.
    """
    instructions = [
        (0, 0, 0, "uvw", True),
        (1, 1, 0, "uvw", True),
    ]
    tp = TensorProduct(
        "0x0e + 1o + 2e",
        "0e + 1o + 2e",
        "0x0e + 1o",
        instructions,
    )

    batch = 4
    x1 = tp.irreps_in1.randn(batch, -1)
    x2 = tp.irreps_in2.randn(batch, -1)

    out = tp(x1, x2)
    assert out.shape == (batch, tp.irreps_out.dim)
    # no instruction leads to the 1o output
    assert torch.all(out == 0.0)

    tp.right(x2)
Example #3
0
def make_tp(l1,
            p1,
            l2,
            p2,
            lo,
            po,
            mode,
            weight,
            mul: int = 25,
            path_weights: bool = True,
            **kwargs):
    """Build a two-block ``TensorProduct`` for tests, or ``None`` if invalid.

    Each of the two inputs and the output carries two blocks of the same
    irrep: one with multiplicity ``mul`` and one with multiplicity 19 (for
    mode ``"uvuv"`` the output multiplicities are squared).  Two paths use
    the requested ``mode``/``weight`` and two extra ``'uvw'`` paths mix the
    blocks, with path weights 0.5 and 0.2 when ``path_weights`` is true and
    1.0 otherwise.  Extra keyword arguments are forwarded to
    ``TensorProduct``.

    Returns ``None`` when the constructor raises ``AssertionError``.
    """
    other_mul = 19

    def out_multiplicity(m):
        # "uvuv" produces one output channel per (u, v) pair.
        return m**2 if mode == "uvuv" else m

    if path_weights:
        w_cross_a, w_cross_b = 0.5, 0.2
    else:
        w_cross_a, w_cross_b = 1.0, 1.0

    irreps_in1 = [(mul, (l1, p1)), (other_mul, (l1, p1))]
    irreps_in2 = [(mul, (l2, p2)), (other_mul, (l2, p2))]
    irreps_out = [
        (out_multiplicity(mul), (lo, po)),
        (out_multiplicity(other_mul), (lo, po)),
    ]
    instructions = [
        (0, 0, 0, mode, weight),
        (1, 1, 1, mode, weight),
        (0, 0, 1, 'uvw', True, w_cross_a),
        (0, 1, 1, 'uvw', True, w_cross_b),
    ]

    try:
        tp = TensorProduct(irreps_in1, irreps_in2, irreps_out, instructions,
                           **kwargs)
    except AssertionError:
        return None
    return tp
Example #4
0
    def __init__(self, irreps_out, num_z, lmax) -> None:
        """Create the edge-attribute product and two fully connected products.

        Args:
            irreps_out: irreps of the final output; converted with
                ``o3.Irreps``.
            num_z: stored on the module; ``num_z**2`` is the multiplicity of
                the scalar (one-hot edge type) input.
            lmax: maximum degree of the spherical harmonics.
        """
        super().__init__()
        self.num_z = num_z

        self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)

        # Multiply the edge type one-hot with the spherical harmonics to get
        # the edge attributes: one unweighted "uvu" path per degree.
        num_edge_types = num_z**2
        one_hot_irreps = [(num_edge_types, "0e")]
        attr_layout = [(num_edge_types, ir) for _, ir in self.irreps_sh]
        per_degree_paths = [
            (0, degree, degree, "uvu", False) for degree in range(lmax + 1)
        ]
        self.mul = TensorProduct(
            one_hot_irreps,
            self.irreps_sh,
            attr_layout,
            per_degree_paths,
        )
        irreps_attr = self.mul.irreps_out

        irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        irreps_out = o3.Irreps(irreps_out)

        # sh -> mid -> out, each step conditioned on the edge attributes.
        self.tp1 = FullyConnectedTensorProduct(self.irreps_sh, irreps_attr,
                                               irreps_mid)
        self.tp2 = FullyConnectedTensorProduct(irreps_mid, irreps_attr,
                                               irreps_out)
Example #5
0
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        """Build the sub-modules of the layer.

        Args:
            irreps_node_input: irreps of the node features (input).
            irreps_node_attr: irreps of the node attributes.
            irreps_edge_attr: irreps of the edge attributes.
            irreps_node_output: irreps of the node features (output).
            fc_neurons: list of layer widths for the network that produces the
                tensor-product weights; the number of weights of the tensor
                product is appended as the final width.
            num_neighbors: stored on the module; not used in this constructor.

        Raises:
            AssertionError: if no product of input and edge irreps lands in
                the allowed outputs, or if the ``alpha`` product cannot
                produce a scalar.
        """
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        node_in = self.irreps_node_input
        node_attr = self.irreps_node_attr
        edge_attr = self.irreps_edge_attr
        node_out = self.irreps_node_output

        self.sc = FullyConnectedTensorProduct(node_in, node_attr, node_out)

        self.lin1 = FullyConnectedTensorProduct(node_in, node_attr, node_in)

        # Enumerate every (input, edge) -> output path whose output irrep is
        # either requested in the node output or is the even scalar.
        paths = []  # entries: (input index, edge index, mul, output irrep)
        for idx_in, (mul, ir_in) in enumerate(node_in):
            for idx_edge, (_, ir_edge) in enumerate(edge_attr):
                for ir_out in ir_in * ir_edge:
                    if not (ir_out in node_out or ir_out == o3.Irrep(0, 1)):
                        continue
                    paths.append((idx_in, idx_edge, mul, ir_out))

        irreps_mid = o3.Irreps([(mul, ir_out) for _, _, mul, ir_out in paths])
        irreps_mid, p, _ = irreps_mid.sort()

        assert irreps_mid.dim > 0, f"irreps_node_input={self.irreps_node_input} time irreps_edge_attr={self.irreps_edge_attr} produces nothing in irreps_node_output={self.irreps_node_output}"

        # Each path was created at position `slot` of the unsorted irreps;
        # remap it through the permutation returned by sort().
        instructions = [
            (idx_in, idx_edge, p[slot], 'uvu', True)
            for slot, (idx_in, idx_edge, _, _) in enumerate(paths)
        ]

        # Weights are supplied externally (produced by self.fc).
        tp = TensorProduct(
            node_in,
            edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel],
                                    torch.nn.functional.silu)
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid, node_attr,
                                                node_out)

        # inspired by https://arxiv.org/pdf/2002.10444.pdf
        self.alpha = FullyConnectedTensorProduct(irreps_mid, node_attr, "0e")
        with torch.no_grad():
            # start with all-zero weights for the alpha product
            self.alpha.weight.zero_()
        assert self.alpha.output_mask[
            0] == 1.0, f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars"
Example #6
0
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        """Build the sub-modules of the layer.

        Args:
            irreps_node_input: irreps of the node features (input).
            irreps_node_attr: irreps of the node attributes.
            irreps_edge_attr: irreps of the edge attributes.
            irreps_node_output: irreps of the node features (output).
            fc_neurons: list of layer widths for the network that produces the
                tensor-product weights; the number of weights of the tensor
                product is appended as the final width.
            num_neighbors: stored on the module; not used in this constructor.
        """
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        self.sc = FullyConnectedTensorProduct(
            self.irreps_node_input,
            self.irreps_node_attr,
            self.irreps_node_output,
        )

        self.lin1 = FullyConnectedTensorProduct(
            self.irreps_node_input,
            self.irreps_node_attr,
            self.irreps_node_input,
        )

        # Collect every (input, edge) -> output path whose output irrep is
        # requested in the node output or is the even scalar.
        mid_entries = []
        raw_paths = []
        for idx_in, (mul, ir_in) in enumerate(self.irreps_node_input):
            for idx_edge, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    wanted = (ir_out in self.irreps_node_output
                              or ir_out == o3.Irrep(0, 1))
                    if not wanted:
                        continue
                    # the path writes to the slot about to be appended
                    raw_paths.append(
                        (idx_in, idx_edge, len(mid_entries), 'uvu', True))
                    mid_entries.append((mul, ir_out))

        irreps_mid, perm, _ = o3.Irreps(mid_entries).sort()

        # Remap output slots through the permutation returned by sort().
        instructions = [
            (idx_in, idx_edge, perm[slot], mode, train)
            for idx_in, idx_edge, slot, mode, train in raw_paths
        ]

        # Weights are supplied externally (produced by self.fc).
        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel],
                                    torch.nn.functional.silu)
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(
            irreps_mid,
            self.irreps_node_attr,
            self.irreps_node_output,
        )
        self.lin3 = FullyConnectedTensorProduct(
            irreps_mid,
            self.irreps_node_attr,
            "0e",
        )
Example #7
0
def test_specialized_code(normalization, mode, weighted, float_tolerance):
    """Compare ``TensorProduct`` with and without ``_specialized_code``.

    Both variants are built from the same irreps and instructions, given the
    same weights, and must agree (within ``float_tolerance``) on the forward
    pass and on ``right``.
    """
    irreps_in1 = Irreps('4x0e + 4x1e + 4x2e')
    irreps_in2 = Irreps('5x0e + 5x1e + 5x2e')
    irreps_out = Irreps('6x0e + 6x1e + 6x2e')

    # Modes that share the "u" index between both inputs need equal inputs.
    if mode in ('uuu', 'uuw'):
        irreps_in2 = irreps_in1

    if mode in ('uvu', 'uuu'):
        irreps_out = irreps_in1
    elif mode == 'uvv':
        irreps_out = irreps_in2
    elif mode == 'uuw' and not weighted:
        # When unweighted, uuw is a plain sum over u and requires an output mul of 1
        irreps_out = Irreps([(1, ir) for _, ir in irreps_out])

    index_triples = [
        (0, 0, 0),
        (0, 1, 1),
        (1, 0, 1),
        (1, 1, 0),
        (1, 1, 1),
        (0, 2, 2),
        (2, 0, 2),
        (2, 2, 0),
        (2, 1, 1),
    ]
    ins = [(i_1, i_2, i_out, mode, weighted, 1.0)
           for i_1, i_2, i_out in index_triples]

    def build(specialized):
        return TensorProduct(irreps_in1,
                             irreps_in2,
                             irreps_out,
                             ins,
                             normalization=normalization,
                             _specialized_code=specialized)

    tp1 = build(False)
    tp2 = build(True)
    with torch.no_grad():
        tp2.weight[:] = tp1.weight

    x = irreps_in1.randn(3, -1)
    y = irreps_in2.randn(3, -1)

    forward_gap = (tp1(x, y) - tp2(x, y)).abs().max()
    assert forward_gap < float_tolerance

    right_gap = (tp1.right(y) - tp2.right(y)).abs().max()
    assert right_gap < float_tolerance
Example #8
0
    def __init__(self, irreps_in, irreps_out, irreps_sh, dim_key):
        """Build the self-interaction, tensor product and output products.

        Args:
            irreps_in: irreps of the input features.  Anything accepted by
                ``o3.Irreps`` (an ``Irreps`` instance, a string, ...).
            irreps_out: irreps of the output features (same accepted forms).
            irreps_sh: irreps of the second tensor-product operand (same
                accepted forms).
            dim_key: first dimension of the learnable ``tp_weight`` matrix;
                its second dimension is the number of weights of ``self.tp``.
        """
        super().__init__()
        # Coerce before simplifying so callers may pass strings/lists as well
        # as Irreps objects.
        self.irreps_in = o3.Irreps(irreps_in).simplify()
        self.irreps_out = o3.Irreps(irreps_out).simplify()
        self.irreps_sh = o3.Irreps(irreps_sh).simplify()

        # self.si = Linear(self.irreps_in, self.irreps_out, internal_weights=True, shared_weights=True)
        self.si = FullyConnectedTensorProduct(self.irreps_in,
                                              o3.Irreps("5x0e"),
                                              self.irreps_out)

        # self.lin1 = Linear(self.irreps_in, self.irreps_in, internal_weights=True, shared_weights=True)
        self.lin1 = FullyConnectedTensorProduct(self.irreps_in,
                                                o3.Irreps("5x0e"),
                                                self.irreps_in)

        # (l, p) pairs present in the output; computed once instead of on
        # every innermost iteration.
        allowed = {(l, p) for _, (l, p) in self.irreps_out}

        instr = []
        # (mul, l, p) -> position in the intermediate irreps.  Dicts keep
        # insertion order, so the keys double as the ordered irreps list.
        slots = {}
        for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
            for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
                p_out = p_1 * p_2
                for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                    if (l_out, p_out) not in allowed:
                        continue
                    # Reuse the slot of an identical (mul, l, p) entry, or
                    # open a new one at the end.
                    i_out = slots.setdefault((mul_1, l_out, p_out),
                                             len(slots))
                    instr.append((i_1, i_2, i_out, 'uvu', True))
        irreps = o3.Irreps(list(slots))

        # Weights are supplied externally rather than stored inside self.tp.
        self.tp = TensorProduct(self.irreps_in,
                                self.irreps_sh,
                                irreps,
                                instr,
                                internal_weights=False,
                                shared_weights=False)

        self.tp_weight = torch.nn.Parameter(
            torch.randn(dim_key, self.tp.weight_numel))

        # self.lin2 = Linear(irreps, self.irreps_out, internal_weights=True, shared_weights=True)
        self.lin2 = FullyConnectedTensorProduct(irreps, o3.Irreps("5x0e"),
                                                self.irreps_out)