示例#1
0
    def __init__(self, Rs, acts):
        '''
        Pair every channel of a scalar representation with an activation.

        Can be used only with scalar fields

        :param Rs: input representation; it is passed through ``o3.simplify``
            and read as ``(mul, l, p)`` entries
        :param acts: list of tuple (multiplicity, activation); a multiplicity
            of -1 is replaced by the number of channels not claimed by the
            entries with a positive multiplicity
        :raises ValueError: if an input of parity -1 goes through an
            activation that is detected as neither even nor odd
        '''
        super().__init__()

        Rs = o3.simplify(Rs)

        total_in = sum(mul for mul, _, _ in Rs)
        total_fixed = sum(mul for mul, _ in acts if mul > 0)

        # normalize the second moment
        acts = [(mul, normalize2mom(fn)) for mul, fn in acts]

        # entries declared with -1 take whatever is left
        for idx, (mul, fn) in enumerate(acts):
            if mul == -1:
                acts[idx] = (total_in - total_fixed, fn)
                assert total_in - total_fixed >= 0

        assert total_in == sum(mul for mul, _ in acts)

        # split entries until `Rs` and `acts` have matching multiplicities
        # position by position
        pos = 0
        while pos < len(Rs):
            mul_rs, l, parity = Rs[pos]
            mul_act, fn = acts[pos]

            if mul_rs < mul_act:
                acts[pos] = (mul_rs, fn)
                acts.insert(pos + 1, (mul_act - mul_rs, fn))
            elif mul_act < mul_rs:
                Rs[pos] = (mul_act, l, parity)
                Rs.insert(pos + 1, (mul_rs - mul_act, l, parity))
            pos += 1

        # sample points used to test numerically whether an activation is
        # even, odd or neither
        probe = torch.linspace(0, 10, 256)

        Rs_out = []
        for (mul, l, p_in), (mul_act, fn) in zip(Rs, acts):
            assert mul == mul_act

            on_pos, on_neg = fn(probe), fn(-probe)
            tol = on_pos.abs().max() * 1e-10
            if (on_pos - on_neg).abs().max() < tol:
                p_act = 1  # even: f(x) == f(-x)
            elif (on_pos + on_neg).abs().max() < tol:
                p_act = -1  # odd: f(x) == -f(-x)
            else:
                p_act = 0

            # only inputs of parity -1 inherit the parity of the activation
            p_out = p_in if p_in != -1 else p_act
            Rs_out.append((mul, 0, p_out))

            if p_in != 0 and p_out == 0:
                raise ValueError("warning! the parity is violated")

        self.Rs_out = o3.simplify(Rs_out)
        self.acts = acts
示例#2
0
 def __repr__(self):
     """Return a one-line summary: class name, both inputs, output and weight count."""
     def pretty(Rs):
         # simplify first so that the printed representation is compact
         return o3.format_Rs(o3.simplify(Rs))

     template = "{name}({Rs_in1} x {Rs_in2} -> {Rs_out} {nw} weights)"
     return template.format(
         name=self.__class__.__name__,
         Rs_in1=pretty(self.Rs_in1),
         Rs_in2=pretty(self.Rs_in2),
         Rs_out=pretty(self.Rs_out),
         nw=self.nweight,
     )
示例#3
0
    def __init__(self, *Rs_outs):
        """
        Store the simplified output representations and derive the input one.

        :param Rs_outs: any number of representations; each one is passed
            through ``o3.simplify``. ``self.Rs_in`` is made of all their
            entries pooled together, sorted by ``(l, p)`` and simplified.
        """
        super().__init__()
        self.Rs_outs = tuple(map(o3.simplify, Rs_outs))

        def order_and_parity(entry):
            _mul, l, p = entry
            return (l, p)

        pooled = []
        for Rs in self.Rs_outs:
            pooled.extend(Rs)
        # stable sort: entries sharing (l, p) keep their relative order
        pooled.sort(key=order_and_parity)

        self.Rs_in = o3.simplify(pooled)
示例#4
0
    def __init__(self, Rs_in, Rs_out):
        """
        Set up a module whose input and output representations must match.

        :param Rs_in: input representation (simplified with ``o3.simplify``)
        :param Rs_out: output representation (simplified with ``o3.simplify``);
            it has to be equal to the simplified ``Rs_in``
        """
        super().__init__()

        self.Rs_in = o3.simplify(Rs_in)
        self.Rs_out = o3.simplify(Rs_out)

        assert self.Rs_in == self.Rs_out

        # one block of ones per output entry, of size mul * (2l + 1)
        mask_parts = []
        for mul, l, _p in self.Rs_out:
            mask_parts.append(torch.ones(mul * (2 * l + 1)))
        self.register_buffer('output_mask', torch.cat(mask_parts))
示例#5
0
def WeightedTensorProduct(Rs_in1,
                          Rs_in2,
                          Rs_out,
                          normalization='component',
                          own_weight=True,
                          weight_batch=False):
    """
    Build a `CustomWeightedTensorProduct` connecting every compatible triplet.

    An instruction ``(i_1, i_2, i_out, 'uvw', True)`` is emitted for each
    combination of entries satisfying ``|l_1 - l_2| <= l_out <= l_1 + l_2``
    and ``p_1 * p_2 == p_out``.

    :param Rs_in1: first input representation (simplified here)
    :param Rs_in2: second input representation (simplified here)
    :param Rs_out: output representation (simplified here)
    :param normalization: forwarded to `CustomWeightedTensorProduct`
    :param own_weight: forwarded to `CustomWeightedTensorProduct`
    :param weight_batch: forwarded to `CustomWeightedTensorProduct`
    :return: the `CustomWeightedTensorProduct` instance
    """
    Rs_in1 = o3.simplify(Rs_in1)
    Rs_in2 = o3.simplify(Rs_in2)
    Rs_out = o3.simplify(Rs_out)

    instr = []
    for i_1, (_, l_1, p_1) in enumerate(Rs_in1):
        for i_2, (_, l_2, p_2) in enumerate(Rs_in2):
            for i_out, (_, l_out, p_out) in enumerate(Rs_out):
                l_allowed = abs(l_1 - l_2) <= l_out <= l_1 + l_2
                if l_allowed and p_1 * p_2 == p_out:
                    instr.append((i_1, i_2, i_out, 'uvw', True))

    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr,
                                       normalization, own_weight, weight_batch)
示例#6
0
def spherical_harmonics_alpha_z_y(Rs, alpha, z, y):
    """
    spherical harmonics

    The alpha factor and the z/y factor are evaluated on flattened inputs,
    combined with `mul_m_lm`, and the result is reshaped to the shape of
    `alpha` followed by one extra trailing dimension.
    """
    Rs = o3.simplify(Rs)
    alpha_part = spherical_harmonics_alpha(o3.lmax(Rs), alpha.flatten())  # [z, m]
    zy_part = spherical_harmonics_z(Rs, z.flatten(), y.flatten())  # [z, l * m]
    combined = mul_m_lm(Rs, alpha_part, zy_part)
    final_shape = alpha.shape + (zy_part.shape[1], )
    return combined.reshape(final_shape)
示例#7
0
def spherical_harmonics_z(Rs, z, y=None):
    """
    the z component of the spherical harmonics
    (useful to perform fourier transform)

    Every entry of `Rs` must have parity 0 or ``(-1)**l``.

    :param z: tensor of shape [...]
    :param y: optional, forwarded as is to `legendre`
    :return: tensor of shape [..., l * m]
    """
    Rs = o3.simplify(Rs)
    orders = []
    for _, l, p in Rs:
        assert p in [0, (-1)**l]
        orders.append(l)
    return legendre(orders, z, y)  # [..., l * m]
示例#8
0
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 Rs_sh,
                 rad_hs,
                 normalization='component'):
        """
        Assemble the sub-modules of the convolution.

        :param Rs_in: input representation (simplified here)
        :param Rs_out: output representation (simplified here)
        :param Rs_sh: representation of the second tensor-product operand
            (simplified here)
        :param rad_hs: tuple of layer sizes given to `FC`; the number of
            weights of the tensor product is appended to it
        :param normalization: stored on the instance as ``self.normalization``
        """
        super().__init__(aggr='add')
        self.Rs_in = o3.simplify(Rs_in)
        self.Rs_out = o3.simplify(Rs_out)
        self.Rs_sh = o3.simplify(Rs_sh)

        self.si = Linear(self.Rs_in, self.Rs_out)
        self.lin1 = Linear(self.Rs_in, self.Rs_in)

        # (l, p) pairs present in the requested output
        wanted = [(l, p) for _, l, p in self.Rs_out]

        instr = []
        Rs_mid = []
        for i_1, (mul_1, l_1, p_1) in enumerate(self.Rs_in):
            for i_2, (_, l_2, p_2) in enumerate(self.Rs_sh):
                for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                    p_out = p_1 * p_2
                    if (l_out, p_out) not in wanted:
                        continue
                    entry = (mul_1, l_out, p_out)
                    # identical entries share a single slot of `Rs_mid`
                    if entry not in Rs_mid:
                        Rs_mid.append(entry)
                    i_out = Rs_mid.index(entry)
                    instr.append((i_1, i_2, i_out, 'uvu', True))

        self.tp = CustomWeightedTensorProduct(self.Rs_in,
                                              self.Rs_sh,
                                              Rs_mid,
                                              instr,
                                              own_weight=False,
                                              weight_batch=True)
        self.nn = FC(rad_hs + (self.tp.nweight, ), swish)
        self.lin2 = Linear(Rs_mid, self.Rs_out)

        self.normalization = normalization
示例#9
0
def rep(Rs, alpha, beta, gamma, parity=None):
    """
    Representation of O(3). Parity applied (-1)**parity times.

    The result is the direct sum of one `irrep(l, alpha, beta, gamma)` block
    per copy of every entry of `Rs`. When `parity` is given, each block is
    multiplied by ``p**parity`` and every entry must have a non-zero parity.
    """
    Rs = o3.simplify(Rs)

    blocks = []
    if parity is None:
        for mul, l, _p in Rs:
            blocks.extend(irrep(l, alpha, beta, gamma) for _ in range(mul))
    else:
        assert all(p != 0 for _, _, p in Rs)
        for mul, l, p in Rs:
            blocks.extend(
                (p**parity) * irrep(l, alpha, beta, gamma)
                for _ in range(mul))
    return direct_sum(*blocks)
示例#10
0
def ElementwiseTensorProduct(Rs_in1, Rs_in2, normalization='component'):
    """
    Build a `CustomWeightedTensorProduct` pairing the two inputs channel-wise.

    Both inputs must carry the same total multiplicity. Their entries are
    split until they line up one to one; each aligned pair then contributes
    one output entry per ``l`` in ``|l_1 - l_2| .. l_1 + l_2``.

    :param Rs_in1: first input representation (simplified here)
    :param Rs_in2: second input representation (simplified here)
    :param normalization: forwarded to `CustomWeightedTensorProduct`
    :return: the `CustomWeightedTensorProduct` instance (built with
        ``own_weight=False``)
    """
    Rs_in1 = o3.simplify(Rs_in1)
    Rs_in2 = o3.simplify(Rs_in2)

    assert sum(mul for mul, _, _ in Rs_in1) == sum(mul for mul, _, _ in Rs_in2)

    # split the larger entry of each pair so that multiplicities match
    pos = 0
    while pos < len(Rs_in1):
        mul_1, l_1, p_1 = Rs_in1[pos]
        mul_2, l_2, p_2 = Rs_in2[pos]

        if mul_1 < mul_2:
            Rs_in2[pos] = (mul_1, l_2, p_2)
            Rs_in2.insert(pos + 1, (mul_2 - mul_1, l_2, p_2))
        elif mul_2 < mul_1:
            Rs_in1[pos] = (mul_2, l_1, p_1)
            Rs_in1.insert(pos + 1, (mul_1 - mul_2, l_1, p_1))
        pos += 1

    Rs_out = []
    instr = []
    for idx, pair in enumerate(zip(Rs_in1, Rs_in2)):
        (mul, l_1, p_1), (mul_2, l_2, p_2) = pair
        assert mul == mul_2
        for l in range(abs(l_1 - l_2), l_1 + l_2 + 1):
            # the new output entry goes at the current end of `Rs_out`
            instr.append((idx, idx, len(Rs_out), 'uuu', False))
            Rs_out.append((mul, l, p_1 * p_2))

    return CustomWeightedTensorProduct(Rs_in1,
                                       Rs_in2,
                                       Rs_out,
                                       instr,
                                       normalization,
                                       own_weight=False)
示例#11
0
    def __init__(self, Rs_in, Rs_out, normalization: str = 'component'):
        """
        Connect every input entry to the output entries sharing its (l, p).

        :param Rs_in: input representation (simplified here)
        :param Rs_out: output representation (simplified here)
        :param normalization: forwarded to `CustomWeightedTensorProduct`

        A buffer ``output_mask`` is registered: it holds ones on the output
        components whose (l, p) also appears in the input, zeros elsewhere.
        """
        super().__init__()

        self.Rs_in = o3.simplify(Rs_in)
        self.Rs_out = o3.simplify(Rs_out)

        instr = []
        for i_in, (_, l_in, p_in) in enumerate(self.Rs_in):
            for i_out, (_, l_out, p_out) in enumerate(self.Rs_out):
                if l_in == l_out and p_in == p_out:
                    instr.append((i_in, 0, i_out, 'uvw', True))

        # the second operand is the single entry (1, 0, 1)
        self.tp = CustomWeightedTensorProduct(self.Rs_in, [(1, 0, 1)],
                                              self.Rs_out,
                                              instr,
                                              normalization,
                                              own_weight=True)

        mask_parts = []
        for mul, l, p in self.Rs_out:
            size = mul * (2 * l + 1)
            fed = any(l_in == l and p_in == p for _, l_in, p_in in self.Rs_in)
            mask_parts.append(torch.ones(size) if fed else torch.zeros(size))
        self.register_buffer('output_mask', torch.cat(mask_parts))
示例#12
0
    def __init__(self, Rs_scalars, act_scalars, Rs_gates, act_gates,
                 Rs_nonscalars):
        """
        Assemble the gated block from its three groups of channels.

        :param Rs_scalars: representation given to the scalar `Activation`
        :param act_scalars: activations given to the scalar `Activation`
        :param Rs_gates: representation given to the gate `Activation`
        :param act_gates: activations given to the gate `Activation`
        :param Rs_nonscalars: representation multiplied element-wise by the
            activated gates

        ``self.Rs_in`` is the input of the `Sortcut` followed by the simplified
        non-scalars; ``self.Rs_out`` is the output of the scalar activation
        followed by the output of the element-wise product.
        """
        super().__init__()

        self.sc = Sortcut(Rs_scalars, Rs_gates)
        self.Rs_scalars, self.Rs_gates = self.sc.Rs_outs
        self.Rs_nonscalars = o3.simplify(Rs_nonscalars)
        self.Rs_in = self.sc.Rs_in + self.Rs_nonscalars

        self.act_scalars = Activation(Rs_scalars, act_scalars)
        activated_scalars = self.act_scalars.Rs_out

        self.act_gates = Activation(Rs_gates, act_gates)
        activated_gates = self.act_gates.Rs_out

        self.mul = nn.ElementwiseTensorProduct(Rs_nonscalars, activated_gates)
        gated_nonscalars = self.mul.Rs_out

        self.Rs_out = activated_scalars + gated_nonscalars
示例#13
0
def make_gated_block(Rs_in, muls, ps, Rs_sh):
    """
    Make a `GatedBlockParity` assuming many things

    :param Rs_in: input representation (simplified before use)
    :param muls: multiplicities indexed by ``l``; ``muls[0]`` is used for the
        scalars
    :param ps: candidate parities
    :param Rs_sh: representation combined with `Rs_in` to find which (l, p)
        pairs can be produced
    :return: the `GatedBlockParity` instance
    """
    # every (l, p) reachable by combining one entry of Rs_in with one of Rs_sh
    Rs_available = []
    for _, l_in, p_in in o3.simplify(Rs_in):
        for _, l_sh, p_sh in Rs_sh:
            for l in range(abs(l_in - l_sh), l_in + l_sh + 1):
                Rs_available.append((l, p_in * p_sh))

    scalars = []
    for p in ps:
        if (0, p) in Rs_available:
            scalars.append((muls[0], 0, p))

    # swish on even scalars, tanh on the others
    act_scalars = []
    for mul, _l, p in scalars:
        act_scalars.append((mul, swish if p == 1 else torch.tanh))

    nonscalars = []
    for l in range(1, len(muls)):
        for p in ps:
            if (l, p * (-1)**l) in Rs_available:
                nonscalars.append((muls[l], l, p * (-1)**l))

    # prefer even gates with a sigmoid; fall back to odd gates with tanh
    if (0, +1) in Rs_available:
        gate_parity, gate_act = +1, torch.sigmoid
    else:
        gate_parity, gate_act = -1, torch.tanh
    gates = [(o3.mul_dim(nonscalars), 0, gate_parity)]
    act_gates = [(-1, gate_act)]

    return GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
示例#14
0
    def __init__(self,
                 muls=(128, 8, 0),
                 ps=(1, -1),
                 lmax=1,
                 num_layers=6,
                 cutoff=10.0,
                 rad_gaussians=50,
                 rad_hs=(128, ),
                 num_neighbors=20,
                 readout='add',
                 dipole=False,
                 mean=None,
                 std=None,
                 scale=None,
                 atomref=None,
                 options="res"):
        """
        Build the embedding, the radial model and the stack of layers.

        :param muls: multiplicities indexed by ``l``; ``muls[0]`` is also the
            width of the embedding
        :param ps: parities given to `make_gated_block` and used for the
            final output ``[(1, 0, p) for p in ps]``
        :param lmax: the second operand of the convolutions has one entry
            ``(1, l, (-1)**l)`` for each ``l`` in ``0..lmax``
        :param num_layers: number of (conv, act, shortcut) layers
        :param cutoff: stored, and given to `GaussianRadialModel`
        :param rad_gaussians: given to `GaussianRadialModel` and used as the
            first size of the radial networks
        :param rad_hs: remaining sizes of the radial networks
        :param num_neighbors: stored on the instance
        :param readout: one of 'add', 'sum', 'mean'; stored on the instance
        :param dipole: stored on the instance
        :param mean: stored on the instance
        :param std: stored on the instance
        :param scale: stored on the instance
        :param atomref: registered as the buffer ``initial_atomref``; when not
            None it is also copied into the weights of an ``Embedding(100, 1)``
        :param options: when it contains 'res', every layer gets a shortcut
        """
        super().__init__()

        assert readout in ['add', 'sum', 'mean']
        self.readout = readout
        self.cutoff = cutoff
        self.dipole = dipole
        self.mean = mean
        self.std = std
        self.scale = scale
        self.num_neighbors = num_neighbors
        self.options = options

        self.embedding = Embedding(100, muls[0])
        self.Rs_in = [(muls[0], 0, 1)]

        self.radial = GaussianRadialModel(rad_gaussians, cutoff)
        # spherical harmonics representation
        self.Rs_sh = [(1, l, (-1)**l) for l in range(lmax + 1)]

        radial_sizes = (rad_gaussians, ) + rad_hs

        Rs = self.Rs_in
        blocks = []
        for _ in range(num_layers):
            act = make_gated_block(Rs, muls, ps, self.Rs_sh)
            conv = Conv(Rs, act.Rs_in, self.Rs_sh, radial_sizes)

            shortcut = None
            if 'res' in self.options:
                # identity when the representation is unchanged, else a
                # linear map
                shortcut_cls = Identity if Rs == act.Rs_out else Linear
                shortcut = shortcut_cls(Rs, act.Rs_out)

            Rs = o3.simplify(act.Rs_out)

            blocks.append(torch.nn.ModuleList([conv, act, shortcut]))

        self.layers = torch.nn.ModuleList(blocks)

        # the last layer is a bare convolution down to one scalar per parity
        self.Rs_out = [(1, 0, p) for p in ps]
        self.layers.append(Conv(Rs, self.Rs_out, self.Rs_sh, radial_sizes))

        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            table = Embedding(100, 1)
            table.weight.data.copy_(atomref)
            self.atomref = table
示例#15
0
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Decompose a tensor with index symmetries into irreducible representations.

    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]

    :param formula: index strings separated by '='; each string may start with
        '-' to mark an antisymmetry. Every string must be a permutation of the
        first one, and the first one must not be negated.
    :param eps: tolerance forwarded to `is_representation`, `reduce` and
        `_round_sqrt`
    :param has_parity: whether the representations carry a parity. When None
        it is deduced from the first non-callable entry of `kw_Rs`.
    :param kw_Rs: maps an index letter either to an Rs (converted with
        `o3.convention`) or to a callable that is called with a group element
        unpacked as positional arguments. Indices related by the formula share
        the same value.
    :return: a pair ``(Rs_out, A)`` where ``Rs_out`` is the simplified list of
        irreps found and ``A`` is the change of basis, cast to the default
        dtype that was active when the function was called. When the
        symmetries leave no degree of freedom, ``([], zeros(0, d))`` is
        returned.
    :raises RuntimeError: if a formula is not a permutation of the first one,
        if parities are specified only on some entries, if `has_parity`
        cannot be determined, if two related indices have different Rs, if an
        index has no Rs, or if the decomposition does not cover the whole space
    """
    # remember the caller's dtype: the result is cast back to it at the end
    dtype = torch.get_default_dtype()
    # NOTE(review): presumably makes float64 the default dtype inside the block — confirm
    with torch_default_dtype(torch.float64):
        # reformat `formulas` and make checks
        formulas = [
            (-1 if f.startswith('-') else 1, f.replace('-', ''))
            for f in formula.split('=')
        ]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')
            if len(f0) != len(f):
                raise RuntimeError(f'{f0} and {f} don\'t have the same number of indices')

        # `formulas` is a list of (sign, permutation of indices)
        # each formula can be viewed as a permutation of the original formula
        formulas = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas}  # set of generators (permutations)

        # they can be composed, for instance if you have ijk=jik=ikj
        # you also have ijk=jki
        # applying all possible compositions creates an entire group
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas])
            formulas = formulas.union([
                (s1 * s2, perm.compose(p1, p2))
                for s1, p1 in formulas
                for s2, p2 in formulas
            ])
            if len(formulas) == n:
                break  # we break when the set is stable => it is now a group \o/

        # let's clean the `kw_Rs` before checking that they are compatible with the formulas
        # (callables are left untouched; `has_parity` is deduced from the first Rs if not given)
        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = o3.convention(kw_Rs[i])
                if has_parity is None:
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{o3.format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{o3.format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                kw_Rs[i] = Rs

        # only reached with None when every entry of `kw_Rs` is a callable (or there is none)
        if has_parity is None:
            raise RuntimeError(f'please specify the argument `has_parity`')

        group = O3() if has_parity else SO3()

        # here we check that each index has one and only one representation
        # and propagate the representation between indices related by a formula
        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(f'Rs of {i} (Rs={o3.format_Rs(kw_Rs[i])}) and {j} (Rs={o3.format_Rs(kw_Rs[j])}) should be the same')
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has not Rs associated to it')

        ide = group.identity()
        # for a callable, the dimension is read from its value at the identity
        dims = {i: len(kw_Rs[i](*ide)) if callable(kw_Rs[i]) else o3.dim(kw_Rs[i]) for i in f0}  # dimension of each index
        full_base = list(itertools.product(*(range(dims[i]) for i in f0)))  # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3)
        # len(full_base) degrees of freedom in an unconstrained tensor

        # but there are constraints given by the group `formulas`
        # For instance if `ij=-ji`, then 00=-00, 01=-10 and so on
        base = set()
        for x in full_base:
            # T[x] is a coefficient of the tensor T and is related to other coefficient T[y]
            # if x and y are related by a formula
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # s * T[x] are all equal for all (s, x) in xs
            # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom
            if not (-1, x) in xs:
                # the sign is arbitrary, put both possibilities
                base.add(frozenset({
                    frozenset(xs),
                    frozenset({(-s, x) for s, x in xs})
                }))

        # len(base) is the number of degrees of freedom in the tensor.
        # Now we want to decompose these degrees of freedom into irreps

        # nested sorting turns the frozensets into lists with a deterministic order
        base = sorted([sorted([sorted(xs) for xs in x]) for x in base])  # required for python 3.7 but not for 3.8 (probably a bug in 3.7)

        # First we compute the change of basis (projection) between full_base and base
        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)

        for i, x in enumerate(base):
            # of the two sign conventions, keep the one with the larger sum of signs
            x = max(x, key=lambda xs: sum(s for s, x in xs))
            for s, e in x:
                j = full_base.index(e)
                # 1/sqrt(len(x)) gives the row a unit norm
                Q[i, j] = s / len(x)**0.5

        # the rows of Q must be orthonormal
        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        # nothing to decompose: the symmetries kill every coefficient
        if d_sym == 0:
            return [], torch.zeros(d_sym, d).to(dtype=dtype)

        # We project the representation on the basis `base`
        def representation(g):
            def re(r):
                if callable(r):
                    return r(*g)
                return o3.rep(r, *g)

            # representation of the unconstrained tensor: one factor per index
            m = kron(*(re(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        # And check that after this projection it is still a representation
        assert is_representation(group, representation, eps)

        # The rest of the code simply extracts the irreps present in this representation
        Rs_out = []
        A = Q.clone()
        for r in group.irrep_indices():
            # stop as soon as the next irrep is larger than what remains to be decomposed
            if group.irrep(r)(ide).shape[0] > d_sym - o3.dim(Rs_out):
                break

            # NOTE(review): `reduce` is presumably the project helper returning the multiplicity,
            # a change of basis and the remaining representation — confirm
            mul, B, representation = reduce(group, representation, group.irrep(r), eps)
            # B acts on the trailing rows only; the leading rows are left unchanged
            A = direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
            A = _round_sqrt(A, eps)

            # with parity `r` is concatenated as a tuple, otherwise it is used as a single value
            if has_parity:
                Rs_out += [(mul,) + r]
            else:
                Rs_out += [(mul, r)]

            if o3.dim(Rs_out) == d_sym:
                break

        if o3.dim(Rs_out) != d_sym:
            raise RuntimeError(f'unable to decompose into irreducible representations')

        return o3.simplify(Rs_out), A.to(dtype=dtype)