Example #1
0
    def __init__(
        self,
        irreps_in,
        irreps_out,
        internal_weights=None,
        shared_weights=None,
    ):
        """Linear map realised as a tensor product with a single ``"0e"`` operand.

        Each input irrep is connected to every output irrep of the same type by
        a ``"uvw"`` instruction whose second operand is the lone ``"0e"``.

        Parameters
        ----------
        irreps_in : anything accepted by ``o3.Irreps``
            Input representation.
        irreps_out : anything accepted by ``o3.Irreps``
            Output representation.
        internal_weights, shared_weights : optional
            Passed through unchanged to ``o3.TensorProduct``.
        """
        super().__init__()

        ins = o3.Irreps(irreps_in)
        outs = o3.Irreps(irreps_out)

        # One instruction per (input slot, output slot) pair of equal irreps;
        # index 0 is the only entry of the "0e" second operand.
        paths = []
        for idx_in, (_, ir_a) in enumerate(ins):
            for idx_out, (_, ir_b) in enumerate(outs):
                if ir_a == ir_b:
                    paths.append((idx_in, 0, idx_out, "uvw", True, 1.0))

        self.tp = o3.TensorProduct(
            ins,
            "0e",
            outs,
            paths,
            internal_weights=internal_weights,
            shared_weights=shared_weights,
        )

        self.output_mask = self.tp.output_mask
        self.irreps_in = ins
        self.irreps_out = outs
Example #2
0
    def __init__(self, irreps_scalars, act_scalars, irreps_gates, act_gates,
                 irreps_gated):
        """Set up a gate nonlinearity.

        ``act_scalars`` is applied to ``irreps_scalars``; ``act_gates`` is applied
        to ``irreps_gates`` and the result is combined with ``irreps_gated`` by
        an ``o3.ElementwiseTensorProduct`` (``self.mul``).

        Parameters
        ----------
        irreps_scalars : anything accepted by ``o3.Irreps``
            Inputs that must all have ``l == 0``.
        act_scalars
            Activation spec handed to ``Activation`` for ``irreps_scalars``.
        irreps_gates : anything accepted by ``o3.Irreps``
            Gate inputs that must all have ``l == 0``; there must be as many of
            them (``num_irreps``) as irreps in ``irreps_gated``.
        act_gates
            Activation spec handed to ``Activation`` for ``irreps_gates``.
        irreps_gated : anything accepted by ``o3.Irreps``
            Irreps combined with the activated gates.

        Raises
        ------
        ValueError
            If ``irreps_gates`` or ``irreps_scalars`` contain an irrep with
            ``l > 0``, or if the gate count differs from ``irreps_gated.num_irreps``.
        """
        super().__init__()
        irreps_scalars = o3.Irreps(irreps_scalars)
        irreps_gates = o3.Irreps(irreps_gates)
        irreps_gated = o3.Irreps(irreps_gated)

        # The `len(...) > 0` guards skip the lmax check for empty irreps.
        if len(irreps_gates) > 0 and irreps_gates.lmax > 0:
            raise ValueError(
                f"Gate scalars must be scalars, instead got irreps_gates = {irreps_gates}"
            )
        if len(irreps_scalars) > 0 and irreps_scalars.lmax > 0:
            raise ValueError(
                f"Scalars must be scalars, instead got irreps_scalars = {irreps_scalars}"
            )
        if irreps_gates.num_irreps != irreps_gated.num_irreps:
            raise ValueError(
                f"There are {irreps_gated.num_irreps} irreps in irreps_gated, but a different number ({irreps_gates.num_irreps}) of gate scalars in irreps_gates"
            )

        # _Sortcut provides the input layout (`irreps_in`) and the three output
        # groups (`irreps_outs`); presumably it splits a sorted input back into
        # scalars / gates / gated — confirm against _Sortcut.
        self.sc = _Sortcut(irreps_scalars, irreps_gates, irreps_gated)
        self.irreps_scalars, self.irreps_gates, self.irreps_gated = self.sc.irreps_outs
        self._irreps_in = self.sc.irreps_in

        # NOTE(review): the activations and the elementwise product below are
        # built from the locally parsed irreps, not from the `self.irreps_*`
        # returned by _Sortcut (which may be simplified) — confirm intended.
        self.act_scalars = Activation(irreps_scalars, act_scalars)
        irreps_scalars = self.act_scalars.irreps_out

        self.act_gates = Activation(irreps_gates, act_gates)
        irreps_gates = self.act_gates.irreps_out

        self.mul = o3.ElementwiseTensorProduct(irreps_gated, irreps_gates)
        irreps_gated = self.mul.irreps_out

        # output layout: activated scalars followed by the gated irreps
        self._irreps_out = irreps_scalars + irreps_gated
Example #3
0
def test_bias():
    """o3.Linear adds biases only to the output segments flagged True."""
    ir_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    ir_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    lin = o3.Linear(ir_in, ir_out, biases=[True, False, False, True, False])
    with torch.no_grad():
        lin.bias[:].fill_(1.0)
    out = lin(torch.zeros(ir_in.dim))

    # With a zero input the output is the bias alone: ones on the 3x0e and
    # 5x0e segments, zeros on 1e (3 comps), the second 3x0e and 0o.
    expected = torch.tensor([1.0] * 3 + [0.0] * 3 + [0.0] * 3 + [1.0] * 5 + [0.0])
    assert torch.allclose(out, expected)

    assert_equivariant(lin)
    assert_auto_jitable(lin)

    lin = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)

    assert_equivariant(lin)
    assert_auto_jitable(lin)
    assert_normalized(
        lin,
        n_weight=100,
        n_input=10_000,
        atol=0.5,
        weights=[lin.weight],
    )
Example #4
0
    def __init__(self, irreps_out, num_z, lmax) -> None:
        """Build an edge-attribute product followed by two fully connected tensor products.

        Parameters
        ----------
        irreps_out : anything accepted by ``o3.Irreps``
            Output representation of ``self.tp2``.
        num_z : int
            Stored as ``self.num_z``; ``num_z**2`` is the multiplicity of the
            scalar operand of ``self.mul`` (presumably one channel per pair of
            node types — confirm).
        lmax : int
            Maximum degree of the spherical harmonics ``self.irreps_sh``.
        """
        super().__init__()
        self.num_z = num_z

        self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)

        # to multiply the edge type one-hot with the spherical harmonics to get the edge attributes
        # One unweighted 'uvu' path per degree; using `l` as both the operand and
        # output index assumes the l-th entry of irreps_sh has degree l.
        self.mul = TensorProduct(
            [(num_z**2, "0e")],
            self.irreps_sh,
            [(num_z**2, ir) for _, ir in self.irreps_sh],
            [
                (0, l, l, "uvu", False)
                for l in range(lmax + 1)
            ]
        )
        irreps_attr = self.mul.irreps_out

        irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        irreps_out = o3.Irreps(irreps_out)

        # harmonics x edge attributes -> hidden features
        self.tp1 = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_sh,
            irreps_in2=irreps_attr,
            irreps_out=irreps_mid,
        )
        # hidden features x edge attributes -> output
        self.tp2 = FullyConnectedTensorProduct(
            irreps_in1=irreps_mid,
            irreps_in2=irreps_attr,
            irreps_out=irreps_out,
        )
def network():
    """Build a small ``Network`` plus a sampler of random input graphs for it.

    Returns
    -------
    tuple
        ``(f, random_graph)``. ``random_graph()`` returns a dict with ``pos``,
        ``x`` and ``z`` entries for a graph of 3 to 7 nodes, sampled to match
        ``f.irreps_in`` and ``f.irreps_node_attr``.
    """
    num_nodes = 5
    irreps_in = o3.Irreps("3x0e + 2x1o")
    irreps_attr = o3.Irreps("10x0e")
    irreps_out = o3.Irreps("2x0o + 2x1o + 2x2e")
    irreps_hidden = o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o")
    irreps_sh = o3.Irreps.spherical_harmonics(3)

    net = Network(
        irreps_in,
        irreps_hidden,
        irreps_out,
        irreps_attr,
        irreps_sh,
        layers=3,
        max_radius=2.0,
        number_of_basis=5,
        radial_layers=2,
        radial_neurons=100,
        num_neighbors=4.0,
        num_nodes=num_nodes,
    )

    def random_graph():
        size = random.randint(3, 7)
        return {
            'pos': torch.randn(size, 3),
            'x': net.irreps_in.randn(size, -1),
            'z': net.irreps_node_attr.randn(size, -1),
        }

    return net, random_graph
def test_full():
    """FullTensorProduct is equivariant and TorchScript-compatible."""
    tp = FullTensorProduct(
        o3.Irreps("1e + 2e + 3x3o"),
        o3.Irreps("1e + 2x2e + 2x3o"),
    )
    print(tp)

    assert_equivariant(tp)
    assert_auto_jitable(tp)
Example #7
0
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        """Build a point-convolution layer with a self-connection and a scalar mixing gate.

        Parameters
        ----------
        irreps_node_input : anything accepted by ``o3.Irreps``
            Representation of the incoming node features.
        irreps_node_attr : anything accepted by ``o3.Irreps``
            Representation of the node attributes; second operand of ``sc``,
            ``lin1``, ``lin2`` and ``alpha``.
        irreps_edge_attr : anything accepted by ``o3.Irreps``
            Representation of the edge attributes; second operand of ``tp``.
        irreps_node_output : anything accepted by ``o3.Irreps``
            Representation of the outgoing node features.
        fc_neurons : list
            Layer sizes of ``self.fc``; ``tp.weight_numel`` is appended as the
            final size.
        num_neighbors
            Stored as ``self.num_neighbors`` (presumably a normalization
            constant used in ``forward`` — confirm).

        Raises
        ------
        AssertionError
            If no path of input x edge irreps yields an irrep of
            ``irreps_node_output`` or ``0e``, or if ``irreps_mid`` and
            ``irreps_node_attr`` cannot produce the scalar output of ``alpha``.
        """
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        # self-connection: node input x node attr -> node output
        self.sc = FullyConnectedTensorProduct(self.irreps_node_input,
                                              self.irreps_node_attr,
                                              self.irreps_node_input if False else self.irreps_node_output) if False else FullyConnectedTensorProduct(self.irreps_node_input,
                                              self.irreps_node_attr,
                                              self.irreps_node_output)
Example #8
0
def test_linear():
    """o3.Linear between equal irreps runs, is equivariant, jitable and normalized."""
    spec = "1e + 2e + 3x3o"
    src = o3.Irreps(spec)
    dst = o3.Irreps(spec)
    lin = o3.Linear(src, dst)
    lin(torch.randn(src.dim))

    assert_equivariant(lin)
    assert_auto_jitable(lin)
    assert_normalized(lin, n_weight=100, n_input=10_000, atol=0.5)
Example #9
0
def test_cat():
    """Adding two Irreps concatenates their entries without merging them."""
    first = o3.Irreps("4x1e + 6x2e + 12x2o")
    second = o3.Irreps("1x1e + 2x2e + 12x4o")
    irreps = first + second

    assert len(irreps) == 6

    # one `l` per irrep copy, in order of appearance
    expected_ls = []
    for mul, l in [(4, 1), (6, 2), (12, 2), (1, 1), (2, 2), (12, 4)]:
        expected_ls += [l] * mul
    assert irreps.ls == expected_ls

    assert irreps.lmax == 4
    assert irreps.num_irreps == sum([4, 6, 12, 1, 2, 12])
Example #10
0
def test_getitem():
    """Indexing an Irreps yields (mul, Irrep) pairs; slicing yields an Irreps."""
    irreps = o3.Irreps("16x1e + 3e + 2e + 5o")
    last = (1, o3.Irrep("5o"))

    assert irreps[0] == (16, o3.Irrep("1e"))
    assert irreps[3] == last
    assert irreps[-1] == last

    tail = irreps[2:]
    assert isinstance(tail, o3.Irreps)
    assert tail == o3.Irreps("2e + 5o")
def test_id():
    """Identity between equal irreps runs, is equivariant and is jitable."""
    spec = "1e + 2e + 3x3o"
    src = o3.Irreps(spec)
    dst = o3.Irreps(spec)
    module = Identity(src, dst)
    print(module)
    module(torch.randn(src.dim))

    assert_equivariant(module)
    assert_auto_jitable(module, strict_shapes=False)
Example #12
0
def tp_path_exists(irreps_in1, irreps_in2, ir_out):
    """Return True if ``ir_out`` occurs in the product of some pair of input irreps.

    Parameters
    ----------
    irreps_in1, irreps_in2 : anything accepted by ``o3.Irreps``
        The two operands; they are simplified before the search.
    ir_out : anything accepted by ``o3.Irrep``
        The irrep to look for.

    Returns
    -------
    bool
        Whether ``ir_out in ir1 * ir2`` holds for at least one pair.
    """
    first = o3.Irreps(irreps_in1).simplify()
    second = o3.Irreps(irreps_in2).simplify()
    target = o3.Irrep(ir_out)

    return any(
        target in ir_a * ir_b
        for _, ir_a in first
        for _, ir_b in second
    )
Example #13
0
def test_assert_equivariant():
    """assert_equivariant must reject an elementwise product of unrelated irreps."""
    def not_equivariant(x1, x2):
        return x1 * x2

    not_equivariant.irreps_in1 = o3.Irreps("2x0e + 1x1e + 3x2o + 1x4e")
    not_equivariant.irreps_in2 = o3.Irreps("2x0o + 3x0o + 3x2e + 1x4o")
    not_equivariant.irreps_out = o3.Irreps("1x1e + 2x0o + 3x2e + 1x4o")

    # elementwise multiplication needs all three layouts to have the same size
    dim = not_equivariant.irreps_in1.dim
    assert dim == not_equivariant.irreps_in2.dim
    assert dim == not_equivariant.irreps_out.dim

    with pytest.raises(AssertionError):
        assert_equivariant(not_equivariant)
def test_fully_connected():
    """FullyConnectedTensorProduct runs, is equivariant and is jitable."""
    spec = "1e + 2e + 3x3o"
    irreps_in1 = o3.Irreps(spec)
    irreps_in2 = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)
    tp = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(tp)
    tp(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))

    assert_equivariant(tp)
    assert_auto_jitable(tp)
Example #15
0
    def __init__(self, irreps_in, irreps_out):
        """Identity map between two representations that are equal once simplified.

        Parameters
        ----------
        irreps_in, irreps_out : anything accepted by ``o3.Irreps``
            Stored simplified as ``self.irreps_in`` / ``self.irreps_out``.

        Raises
        ------
        AssertionError
            If the simplified representations differ.
        """
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in).simplify()
        self.irreps_out = o3.Irreps(irreps_out).simplify()

        assert self.irreps_in == self.irreps_out

        # Every output component is produced, so the mask is all ones.
        # Build it from the total size directly: concatenating per-irrep chunks
        # with torch.cat fails when the irreps are empty (empty tensor list).
        total_dim = sum(mul * (2 * l + 1) for mul, (l, _p) in self.irreps_out)
        output_mask = torch.ones(total_dim)
        self.register_buffer('output_mask', output_mask)
Example #16
0
    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
        """Set up a convolution whose kernel is sampled on a fixed 3D lattice.

        Parameters
        ----------
        irreps_in, irreps_out : anything accepted by ``o3.Irreps``
            Input and output feature representations.
        irreps_sh : anything accepted by ``o3.Irreps``
            Spherical-harmonic irreps evaluated at each lattice offset.
        diameter : float
            Kernel diameter; ``r = diameter / 2`` bounds the lattice along each
            axis and is the end of the radial basis.
        num_radial_basis : int
            Number of radial basis functions (``soft_one_hot_linspace``).
        steps : tuple of three floats
            Lattice spacing along the first, second and third axis.
        **kwargs
            Stored as ``self.kwargs``; ``padding`` is filled in when absent.
            (Presumably forwarded to a convolution call in ``forward`` — confirm.)
        """
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)

        self.num_radial_basis = num_radial_basis

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors
        r = diameter / 2

        # Per axis: symmetric grid of offsets -s*step .. s*step with s = floor(r / step).
        s = math.floor(r / steps[0])
        x = torch.arange(-s, s + 1.0) * steps[0]

        s = math.floor(r / steps[1])
        y = torch.arange(-s, s + 1.0) * steps[1]

        s = math.floor(r / steps[2])
        z = torch.arange(-s, s + 1.0) * steps[2]

        # NOTE(review): no `indexing=` is passed, so torch's default ('ij') applies;
        # newer PyTorch versions emit a warning for this call.
        lattice = torch.stack(torch.meshgrid(x, y, z), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('lattice', lattice)

        # Default padding: half the lattice extent along each spatial axis.
        # (The generator's `s` is local to it and does not reuse the `s` above.)
        if 'padding' not in kwargs:
            kwargs['padding'] = tuple(s // 2 for s in lattice.shape[:3])
        self.kwargs = kwargs

        # Radial embedding of each offset's length on [0, r].
        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=r,
            number=self.num_radial_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        self.register_buffer('emb', emb)

        sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=lattice,
            normalize=True,
            normalization='component'
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', sh)

        # Weights are not stored inside tp (shared_weights=False); they come from
        # self.weight, which has one row per radial basis function.
        self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))
Example #17
0
    def __init__(self,
                 irreps: o3.Irreps,
                 act,
                 res,
                 normalization='component',
                 lmax_out=None,
                 random_rot=False):
        """Set up a pointwise activation evaluated on a spherical grid.

        Builds ``self.to_s2`` (``o3.ToS2Grid``), ``self.from_s2``
        (``o3.FromS2Grid``) and ``self.act`` (``normalize2mom(act)``);
        presumably ``forward`` maps coefficients to the grid, applies ``act``
        and projects back.

        Parameters
        ----------
        irreps : o3.Irreps or anything accepted by it
            After ``simplify()``, must contain exactly one copy of each degree
            ``0..lmax``. Parities must either all equal ``p_val`` (the parity of
            the ``l = 0`` entry) or follow ``p_val * (-1)**l``.
        act : callable
            Scalar activation function.
        res
            Grid resolution passed to ``ToS2Grid`` / ``FromS2Grid``.
        normalization : str
            Passed to ``ToS2Grid`` / ``FromS2Grid``.
        lmax_out : int, optional
            Maximum output degree; defaults to the input ``lmax``.
        random_rot : bool
            Stored as ``self.random_rot``.

        Raises
        ------
        AssertionError
            If ``irreps`` is not of the required form or its parity pattern is
            not well defined.
        ValueError
            If the input is odd (``p_val == -1``) and ``act`` is neither even nor
            odd, so no consistent output parity exists.
        """
        super().__init__()

        irreps = o3.Irreps(irreps).simplify()
        _, (_, p_val) = irreps[0]
        _, (lmax, _) = irreps[-1]
        # exactly one copy of each degree 0..lmax
        assert all(mul == 1 for mul, _ in irreps)
        assert irreps.ls == list(range(lmax + 1))
        if all(p == p_val for _, (l, p) in irreps):
            p_arg = 1
        elif all(p == p_val * (-1)**l for _, (l, p) in irreps):
            p_arg = -1
        else:
            # Raise explicitly rather than `assert False`: assert statements are
            # removed under `python -O`, which would leave `p_arg` unbound and
            # surface later as an unrelated NameError.
            raise AssertionError("the parity of the input is not well defined")
        self.irreps_in = irreps
        # the input transforms as : A_l ---> p_val * (p_arg)^l * A_l
        # the sphere signal transforms as : f(r) ---> p_val * f(p_arg * r)
        if lmax_out is None:
            lmax_out = lmax

        if p_val in (0, +1):
            self.irreps_out = o3.Irreps([(1, (l, p_val * p_arg**l))
                                         for l in range(lmax_out + 1)])
        if p_val == -1:
            # Probe act on [0, 10] to classify it as even or odd.
            x = torch.linspace(0, 10, 256)
            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
                # p_act = 1 (even activation)
                self.irreps_out = o3.Irreps([(1, (l, p_arg**l))
                                             for l in range(lmax_out + 1)])
            elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
                # p_act = -1 (odd activation)
                self.irreps_out = o3.Irreps([(1, (l, -p_arg**l))
                                             for l in range(lmax_out + 1)])
            else:
                # p_act = 0 (neither even nor odd)
                raise ValueError("warning! the parity is violated")

        self.to_s2 = o3.ToS2Grid(lmax, res, normalization=normalization)
        self.from_s2 = o3.FromS2Grid(res,
                                     lmax_out,
                                     normalization=normalization,
                                     lmax_in=lmax)
        self.act = normalize2mom(act)
        self.random_rot = random_rot
Example #18
0
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        """Build a point-convolution layer with a self-connection.

        Parameters
        ----------
        irreps_node_input : anything accepted by ``o3.Irreps``
            Representation of the incoming node features.
        irreps_node_attr : anything accepted by ``o3.Irreps``
            Representation of the node attributes; second operand of ``sc``,
            ``lin1``, ``lin2`` and ``lin3``.
        irreps_edge_attr : anything accepted by ``o3.Irreps``
            Representation of the edge attributes; second operand of ``tp``.
        irreps_node_output : anything accepted by ``o3.Irreps``
            Representation of the outgoing node features.
        fc_neurons : list
            Layer sizes of ``self.fc``; ``tp.weight_numel`` is appended as the
            final size.
        num_neighbors
            Stored as ``self.num_neighbors`` (presumably a normalization
            constant used in ``forward`` — confirm).
        """
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        # self-connection: node input x node attr -> node output
        self.sc = FullyConnectedTensorProduct(self.irreps_node_input,
                                              self.irreps_node_attr,
                                              self.irreps_node_output)

        # node input x node attr -> node input
        self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input,
                                                self.irreps_node_attr,
                                                self.irreps_node_input)

        # Enumerate every path (input irrep i) x (edge irrep j) -> ir_out and keep
        # outputs present in irreps_node_output or equal to 0e (Irrep(0, 1)).
        # Each kept path gets its own slot k in irreps_mid with the input
        # multiplicity ('uvu').
        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output or ir_out == o3.Irrep(
                            0, 1):
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = o3.Irreps(irreps_mid)
        # sort() also returns the permutation `p`, used below to remap the slot
        # indices the instructions were built against.
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train)
                        for i_1, i_2, i_out, mode, train in instructions]

        # internal_weights=False / shared_weights=False: weights are supplied at
        # call time (presumably produced by self.fc — confirm in forward).
        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        # last layer emits exactly tp.weight_numel values
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel],
                                    torch.nn.functional.silu)
        self.tp = tp

        # mid features x node attr -> node output
        self.lin2 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr,
                                                self.irreps_node_output)
        # mid features x node attr -> a single scalar
        self.lin3 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr, "0e")
Example #19
0
def test_jit_trace():
    """assert_auto_jitable must fail on a module whose output shape depends on its input."""
    @compile_mode('trace')
    class NotTracable(torch.nn.Module):
        def forward(self, param):
            # data-dependent branch: a trace records only one of the two paths
            if param.shape[0] == 7:
                return torch.ones(8)
            else:
                return torch.randn(8, 3)

    module = NotTracable()
    module.irreps_in = o3.Irreps("2x0e")
    module.irreps_out = o3.Irreps("1x1o")
    # TorchScript raises assorted exception types here, so accept any of them.
    with pytest.raises(Exception):
        assert_auto_jitable(module)
Example #20
0
    def __init__(
        self,
        irreps_in1,
        irreps_in2,
        filter_ir_out=None,
        **kwargs
    ):
        """Elementwise tensor product: the n-th irrep copy of the first input is
        combined only with the n-th irrep copy of the second input.

        Parameters
        ----------
        irreps_in1, irreps_in2 : anything accepted by ``o3.Irreps``
            Must contain the same total number of irrep copies (``num_irreps``).
        filter_ir_out : iterable, optional
            If given, output irreps not in this collection are dropped.
        **kwargs
            Passed to the parent constructor together with the computed
            layouts and instructions.

        Raises
        ------
        AssertionError
            If the two inputs have a different ``num_irreps``.
        """

        irreps_in1 = o3.Irreps(irreps_in1).simplify()
        irreps_in2 = o3.Irreps(irreps_in2).simplify()
        if filter_ir_out is not None:
            filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

        assert irreps_in1.num_irreps == irreps_in2.num_irreps

        irreps_in1 = list(irreps_in1)
        irreps_in2 = list(irreps_in2)

        # Split (mul, ir) entries so that, position by position, both lists have
        # the same multiplicity: the larger entry is cut into the smaller mul plus
        # a remainder inserted right after it. The equal num_irreps guarantees
        # the two lists end up the same length.
        i = 0
        while i < len(irreps_in1):
            mul_1, ir_1 = irreps_in1[i]
            mul_2, ir_2 = irreps_in2[i]

            if mul_1 < mul_2:
                irreps_in2[i] = (mul_1, ir_2)
                irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2))

            if mul_2 < mul_1:
                irreps_in1[i] = (mul_2, ir_1)
                irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1))
            i += 1

        # One unweighted 'uuu' instruction per aligned pair and per irrep in
        # their product (subject to filter_ir_out).
        out = []
        instr = []
        for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)):
            assert mul == mul_2
            for ir in ir_1 * ir_2:

                if filter_ir_out is not None and ir not in filter_ir_out:
                    continue

                i_out = len(out)
                out.append((mul, ir))
                instr += [
                    (i, i, i_out, 'uuu', False)
                ]

        super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs)
Example #21
0
def test_normalization(float_tolerance, instance):
    """BatchNorm centres and normalizes scalars and vectors in both modes."""
    sqrt_float_tolerance = torch.sqrt(float_tolerance)

    batch, n = 20, 20
    irreps = o3.Irreps("3x0e + 4x1e")

    for normalization in ('norm', 'component'):
        bn = BatchNorm(irreps, normalization=normalization, instance=instance)

        y = torch.randn(batch, n, irreps.dim).mul(5.0).add(10.0)
        y = bn(y)

        scalars = y[..., :3]  # [batch, space, mul]
        assert scalars.mean([0, 1]).abs().max() < float_tolerance
        assert scalars.pow(2).mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance

        vectors = y[..., 3:].reshape(batch, n, 4, 3)  # [batch, space, mul, repr]
        squared = vectors.pow(2)
        # 'norm' checks the squared norm of each vector, 'component' the mean
        # squared component.
        if normalization == 'norm':
            per_vector = squared.sum(3)
        else:
            per_vector = squared.mean(3)
        assert per_vector.mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance
def test_weird_irreps():
    """Spherical harmonics accept unusual irreps and reject inconsistent ones."""
    # irreps given as a string
    o3.spherical_harmonics("0e + 1o", torch.randn(1, 3), False)

    # multiplicities larger than one
    sh_irreps = o3.Irreps("1x0e + 4x1o + 3x2e")
    values = o3.spherical_harmonics(sh_irreps, torch.randn(7, 3), True)
    assert values.shape[-1] == sh_irreps.dim

    # Bad parity: with a vector ("1o") input, the l = 1 output must be odd, not "1e".
    with pytest.raises(ValueError):
        o3.SphericalHarmonics(
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
            normalization='integral',
            irreps_in="1o",
        )

    # Consistent parity with a pseudovector ("1e") input.
    o3.SphericalHarmonics(
        irreps_in="1e",
        irreps_out="1x0e + 4x1e + 3x2e",
        normalize=True,
    )

    # An input of "1e + 3o" is rejected.
    with pytest.raises(ValueError):
        o3.SphericalHarmonics(
            irreps_in="1e + 3o",
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
        )
Example #23
0
    def __init__(
        self,
        muls,
        sh_lmax,
        num_layers,
        max_radius,
        num_basis,
        fc_neurons,
        num_neighbors,
        num_nodes,
        atomref=None,
    ) -> None:
        """Wrap a ``MessagePassing`` stack with a hidden layout built from ``muls``.

        Parameters
        ----------
        muls : sequence of int
            ``muls[l]`` is the multiplicity of both parities of degree ``l`` in
            the hidden node representation.
        sh_lmax : int
            Maximum degree of the spherical harmonics used as edge attributes.
        num_layers : int
            Passed as ``layers`` to ``MessagePassing``.
        max_radius, num_basis, num_nodes
            Stored on the module.
        fc_neurons : list
            Hidden sizes of the radial network; ``num_basis`` is prepended.
        num_neighbors
            Passed through to ``MessagePassing``.
        atomref : optional
            Registered as the ``atomref`` buffer.
        """
        super().__init__()

        self.sh_lmax = sh_lmax
        self.max_radius = max_radius
        self.num_basis = num_basis
        self.num_nodes = num_nodes

        self.register_buffer('atomref', atomref)

        # for each degree l: an odd then an even copy, both with multiplicity muls[l]
        hidden = []
        for l, mul in enumerate(muls):
            hidden += [(mul, (l, -1)), (mul, (l, 1))]
        irreps_node_hidden = o3.Irreps(hidden)

        self.mp = MessagePassing(
            irreps_node_input="0e",
            irreps_node_hidden=irreps_node_hidden,
            irreps_node_output="0e + 0o",
            irreps_node_attr="5x0e",
            irreps_edge_attr=o3.Irreps.spherical_harmonics(sh_lmax),
            layers=num_layers,
            fc_neurons=[self.num_basis] + fc_neurons,
            num_neighbors=num_neighbors,
        )
def test_module(normalization, normalize):
    """The scripted SphericalHarmonics module matches the functional form."""
    irreps = o3.Irreps("0e + 1o + 3o")
    module = o3.SphericalHarmonics(irreps, normalize, normalization)
    scripted = assert_auto_jitable(module)
    points = torch.randn(11, 3)

    got = scripted(points)
    expected = o3.spherical_harmonics(irreps, points, normalize, normalization)
    assert torch.allclose(got, expected)

    assert_equivariant(module)
Example #25
0
    def __init__(self) -> None:
        """Chain two fully connected tensor products over l <= 3 spherical harmonics."""
        super().__init__()
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)
        sh = self.irreps_sh
        hidden = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        target = o3.Irreps("0o + 6x0e")

        # sh x sh -> hidden
        self.tp1 = FullyConnectedTensorProduct(
            irreps_in1=sh,
            irreps_in2=sh,
            irreps_out=hidden,
        )
        # hidden x sh -> target
        self.tp2 = FullyConnectedTensorProduct(
            irreps_in1=hidden,
            irreps_in2=sh,
            irreps_out=target,
        )
        self.irreps_out = self.tp2.irreps_out
Example #26
0
    def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Evaluate the network on a (possibly batched) point cloud.

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            Must provide ``pos`` (node positions). Optional entries:
            - ``x``: input node features, read only when the network expects them
            - ``z``: node attributes (for instance the atom type), read only when
              the network expects them
            - ``batch``: graph index of each node; all zeros when absent

        Returns
        -------
        torch.Tensor
            The node outputs of the last layer or, when ``self.reduce_output``
            is set, those outputs scattered over ``batch`` and divided by
            ``self.num_nodes ** 0.5``.
        """
        pos = data['pos']
        n_nodes = pos.shape[0]

        if 'batch' not in data:
            batch = pos.new_zeros(n_nodes, dtype=torch.long)
        else:
            batch = data['batch']

        # edges from radius_graph within max_radius; vectors point src - dst
        edge_index = radius_graph(pos, self.max_radius, batch)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]
        edge_vec = pos[edge_src] - pos[edge_dst]

        edge_length = edge_vec.norm(dim=1)
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr,
                                         edge_vec,
                                         True,
                                         normalization='component')

        radial = soft_one_hot_linspace(
            x=edge_length,
            start=0.0,
            end=self.max_radius,
            number=self.number_of_basis,
            basis='gaussian',
            cutoff=False,
        )
        edge_length_embedded = radial.mul(self.number_of_basis**0.5)

        # harmonics weighted per edge by smooth_cutoff of the relative length
        envelope = smooth_cutoff(edge_length / self.max_radius)
        edge_attr = envelope[:, None] * edge_sh

        if self.input_has_node_in and 'x' in data:
            assert self.irreps_in is not None
            x = data['x']
        else:
            assert self.irreps_in is None
            x = pos.new_ones((n_nodes, 1))

        if self.input_has_node_attr and 'z' in data:
            z = data['z']
        else:
            assert self.irreps_node_attr == o3.Irreps("0e")
            z = pos.new_ones((n_nodes, 1))

        for layer in self.layers:
            x = layer(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

        if not self.reduce_output:
            return x
        return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
Example #27
0
def test_arithmetic():
    """Scalar multiplication, products and sums of Irrep/Irreps behave as expected."""
    assert 3 * o3.Irrep("6o") == o3.Irreps("3x6o")

    products = list(o3.Irrep("1o") * o3.Irrep("2e"))
    assert products == [o3.Irrep(name) for name in ("1o", "2o", "3o")]

    assert o3.Irrep("4o") + o3.Irrep("7e") == o3.Irreps("4o + 7e")

    # repetition concatenates rather than scaling multiplicities
    repeated = o3.Irreps("2x2e + 4x1o + 2x2e + 4x1o")
    assert 2 * o3.Irreps("2x2e + 4x1o") == repeated
    assert o3.Irreps("2x2e + 4x1o") * 2 == repeated

    combined = o3.Irreps("1o + 4o") + o3.Irreps("1o + 7e")
    assert combined == o3.Irreps("1o + 4o + 1o + 7e")
Example #28
0
    def __init__(self, irreps_in, irreps_out, irreps_sh, dim_key):
        """Build the tensor products of a key-conditioned convolution-like layer.

        Parameters
        ----------
        irreps_in, irreps_out, irreps_sh : o3.Irreps or anything accepted by it
            Input, output and spherical-harmonic representations; each is
            simplified before use.
        dim_key : int
            Number of rows of ``self.tp_weight`` (one ``tp.weight_numel``-sized
            weight vector per row).
        """
        super().__init__()
        # o3.Irreps(...) keeps Irreps inputs working and also accepts strings.
        self.irreps_in = o3.Irreps(irreps_in).simplify()
        self.irreps_out = o3.Irreps(irreps_out).simplify()
        self.irreps_sh = o3.Irreps(irreps_sh).simplify()

        # self-interaction: input x "5x0e" -> output
        self.si = FullyConnectedTensorProduct(self.irreps_in,
                                              o3.Irreps("5x0e"),
                                              self.irreps_out)

        # input x "5x0e" -> input
        self.lin1 = FullyConnectedTensorProduct(self.irreps_in,
                                                o3.Irreps("5x0e"),
                                                self.irreps_in)

        # (l, p) pairs present in the output; built once instead of rebuilding
        # a list for every candidate path.
        allowed_out = {(l, p) for _, (l, p) in self.irreps_out}

        # For every input x harmonic pair, keep each coupled (l_out, p_out) that
        # the output can hold. Paths with equal (mul, (l, p)) share one slot.
        instr = []
        irreps = []
        for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
            for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
                p_out = p_1 * p_2
                for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                    if (l_out, p_out) not in allowed_out:
                        continue
                    # (mul, (l, p)) is the form o3.Irreps is built from elsewhere
                    r = (mul_1, (l_out, p_out))
                    if r in irreps:
                        i_out = irreps.index(r)
                    else:
                        i_out = len(irreps)
                        irreps.append(r)
                    instr += [(i_1, i_2, i_out, 'uvu', True)]
        irreps = o3.Irreps(irreps)
        # weights are external to tp (internal_weights=False, shared_weights=False)
        self.tp = TensorProduct(self.irreps_in,
                                self.irreps_sh,
                                irreps,
                                instr,
                                internal_weights=False,
                                shared_weights=False)

        self.tp_weight = torch.nn.Parameter(
            torch.randn(dim_key, self.tp.weight_numel))

        # tp output x "5x0e" -> output
        self.lin2 = FullyConnectedTensorProduct(irreps, o3.Irreps("5x0e"),
                                                self.irreps_out)
Example #29
0
    def __init__(self, *irreps_outs):
        """Sort the concatenation of several irreps and set up cutting it back apart.

        ``self.irreps_in`` is the sorted, simplified concatenation of all pieces.
        ``self.cut`` is an ``Extract`` over the sorted (unsimplified)
        concatenation; presumably it maps such an input back to each of
        ``self.irreps_outs`` in order — confirm against ``Extract``.

        Parameters
        ----------
        *irreps_outs : anything accepted by ``o3.Irreps``
            The pieces; each is simplified and stored in ``self.irreps_outs``.
        """
        super().__init__()
        self.irreps_outs = tuple(
            o3.Irreps(irreps).simplify() for irreps in irreps_outs)
        irreps_in = sum(self.irreps_outs, o3.Irreps([]))

        # instructions[k]: indices, within the unsorted concatenation, of the
        # entries that belong to piece k.
        i = 0
        instructions = []
        for irreps_out in self.irreps_outs:
            instructions += [tuple(range(i, i + len(irreps_out)))]
            i += len(irreps_out)
        assert len(irreps_in) == i, (len(irreps_in), i)

        # Remap those indices through the permutation returned by sort().
        # (The generator's `i` is local and does not reuse the counter above.)
        irreps_in, p, _ = irreps_in.sort()
        instructions = [tuple(p[i] for i in x) for x in instructions]

        self.cut = Extract(irreps_in, self.irreps_outs, instructions)
        self.irreps_in = irreps_in.simplify()
Example #30
0
    def __init__(self, irreps_in, squared: bool = False):
        """Set up a tensor product of ``irreps_in`` with itself onto scalars.

        Each (mul, ir) entry of the simplified input is paired with itself in
        ``'uuu'`` mode, producing ``mul`` ``0e`` outputs. Presumably this
        computes per-channel norms in ``forward``; ``squared`` is only stored
        here.

        Parameters
        ----------
        irreps_in : anything accepted by ``o3.Irreps``
            Input representation (simplified before use).
        squared : bool
            Stored as ``self.squared``.
        """
        super().__init__()

        irreps_in = o3.Irreps(irreps_in).simplify()
        muls = [mul for mul, _ in irreps_in]
        irreps_out = o3.Irreps([(mul, "0e") for mul in muls])

        # entry k with itself into output slot k; unweighted, with ir.dim as
        # the final instruction field
        instr = []
        for k, (_, ir) in enumerate(irreps_in):
            instr.append((k, k, k, 'uuu', False, ir.dim))

        self.tp = o3.TensorProduct(
            irreps_in,
            irreps_in,
            irreps_out,
            instr,
            normalization='component',
        )

        self.irreps_in = irreps_in
        self.irreps_out = irreps_out.simplify()
        self.squared = squared