Example 1
    def forward(self, features_1, features_2, weight=None):
        """
        :return:         tensor [..., channel]
        """
        *size, n = features_1.size()
        features_1 = features_1.reshape(-1, n)
        assert n == rs.dim(self.Rs_in1), f"{n} is not {rs.dim(self.Rs_in1)}"
        *size2, n = features_2.size()
        features_2 = features_2.reshape(-1, n)
        assert n == rs.dim(self.Rs_in2), f"{n} is not {rs.dim(self.Rs_in2)}"
        assert size == size2

        if weight is None:
            weight = self.weight
        weight = weight.reshape(-1, self.nweight)
        if weight.shape[0] == 1:
            weight = weight.repeat(features_1.shape[0], 1)

        wigners = [getattr(self, arg) for arg in self.wigners_names]

        if features_1.shape[0] == 0:
            return torch.zeros(*size, rs.dim(self.Rs_out))

        features = self.main(*wigners, features_1, features_2, weight)
        return features.reshape(*size, -1)
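The assertions above check the last feature dimension against `rs.dim` of the corresponding representation list. As a point of reference, here is a minimal pure-Python sketch of that dimension count, assuming the `(mul, l)` / `(mul, l, p)` tuple convention used throughout these examples (the helper name `dim_of_Rs` is hypothetical, not part of e3nn):

# Hypothetical stand-in for rs.dim, shown only to illustrate the convention:
# each (mul, l, ...) entry contributes mul copies of a (2l + 1)-dimensional irrep.
def dim_of_Rs(Rs):
    return sum(mul * (2 * l + 1) for mul, l, *_ in Rs)

assert dim_of_Rs([(3, 0), (2, 1), (5, 2)]) == 3 * 1 + 2 * 3 + 5 * 5  # 34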
Example 2
    def forward(self):
        """
        :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
        """
        # (1) Case r > 0

        # use the radial model to fix all the degrees of freedom
        # note: for the normalization we assume that the variance of R[i] is one
        R = self.R(self.radii[self.radii > self.r_eps]
                   )  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        RY = rsh.mul_radial_angular(self.Rs_f, R, self.Y)

        if R.shape[0] == 0:
            kernel1 = torch.zeros(0, rs.dim(self.Rs_out), rs.dim(self.Rs_in))
        else:
            kernel1 = self.tp.right(RY)

        # (2) Case r = 0

        if self.linear is not None:
            kernel2 = self.linear()

            kernel = kernel1.new_zeros(len(self.radii), *kernel2.shape)
            kernel[self.radii > self.r_eps] = kernel1
            kernel[self.radii <= self.r_eps] = kernel2
        else:
            kernel = kernel1

        return kernel.reshape(*self.size, *kernel1.shape[1:])
Example 3
File: linear.py Project: zizai/e3nn
    def forward(self):
        """
        :return: tensor [l_out * mul_out * m_out, l_in * mul_in * m_in]
        """
        kernel = self.weight.new_zeros(rs.dim(self.Rs_out), rs.dim(self.Rs_in))
        begin_w = 0

        begin_out = 0
        for mul_out, l_out, p_out in self.Rs_out:
            s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
            begin_out += mul_out * (2 * l_out + 1)

            n_path = 0

            begin_in = 0
            for mul_in, l_in, p_in in self.Rs_in:
                s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
                begin_in += mul_in * (2 * l_in + 1)

                if (l_out, p_out) == (l_in, p_in):
                    weight = self.weight[begin_w: begin_w + mul_out * mul_in].reshape(mul_out, mul_in)  # [mul_out, mul_in]
                    begin_w += mul_out * mul_in

                    eye = torch.eye(2 * l_in + 1, dtype=self.weight.dtype, device=self.weight.device)
                    kernel[s_out, s_in] = torch.einsum('uv,ij->uivj', weight, eye).reshape(mul_out * (2 * l_out + 1), mul_in * (2 * l_in + 1))
                    n_path += mul_in

            if n_path > 0:
                kernel[s_out] /= math.sqrt(n_path)

        return kernel
Example 4
def kernel_fn_forward(Y, R, norm_coef, Rs_in, Rs_out, selection_rule,
                      set_of_l_filters):
    """
    :param Y: tensor [batch, l_filter * m_filter]
    :param R: tensor [batch, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in]
    :return: tensor [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    batch = Y.shape[0]
    n_in = rs.dim(Rs_in)
    n_out = rs.dim(Rs_out)

    kernel = Y.new_zeros(batch, n_out, n_in)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = selection_rule(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, begin_R:begin_R + n].reshape(
                batch, mul_out, mul_in,
                len(l_filters))  # [batch, mul_out, mul_in, l_filter]
            begin_R += n

            # note: I don't know if we can vectorize this for loop because [l_filter * m_filter] cannot be put into [l_filter, m_filter]
            K = 0
            for k, l_filter in enumerate(l_filters):
                tmp = sum(2 * l + 1 for l in set_of_l_filters if l < l_filter)
                sub_Y = Y[:, tmp:tmp + 2 * l_filter + 1]  # [batch, m]

                C = o3.wigner_3j(l_out,
                                 l_in,
                                 l_filter,
                                 cached=True,
                                 like=kernel)  # [m_out, m_in, m]

                # note: The multiplication with `sub_R` could also be done outside of the for loop
                K += norm_coef[i, j] * torch.einsum(
                    "ijk,zk,zuv->zuivj",
                    (C, sub_Y,
                     sub_R[..., k]))  # [batch, mul_out, m_out, mul_in, m_in]

            if not isinstance(K, int):
                kernel[:, s_out, s_in] = K.reshape_as(kernel[:, s_out, s_in])
    return kernel
Example 5
def test_tensor_product_to_dense():
    with o3.torch_default_dtype(torch.float64):
        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        mul = rs.TensorProduct(Rs_1, Rs_2, o3.selection_rule)
        assert mul.to_dense().shape == (rs.dim(mul.Rs_out), rs.dim(Rs_1),
                                        rs.dim(Rs_2))
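For context on the `o3.selection_rule` argument used here: it selects which filter orders l couple a given (l1, l2) pair. A minimal sketch of the standard triangle-inequality rule (the real `o3.selection_rule` also handles parity and an lmax cutoff, so its signature differs; the helper below is hypothetical):

# Hypothetical simplified selection rule: l is allowed iff |l1 - l2| <= l <= l1 + l2,
# the condition under which the Wigner 3j symbol (l1, l2, l) can be non-zero.
def triangle_selection_rule(l1, l2):
    return list(range(abs(l1 - l2), l1 + l2 + 1))

assert triangle_selection_rule(1, 2) == [1, 2, 3]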
Example 6
    def test_change_lmax(self):
        lmax = 0
        mul = 1
        signal = torch.zeros(rs.dim([(mul, lmax)]))
        sph = sphten.SphericalTensor(signal, mul, lmax)
        lmax_new = 5
        sph_new = sph.change_lmax(lmax_new)
        assert sph_new.signal.shape[0] == rs.dim(sph_new.Rs)
Example 7
    def change_lmax(self, lmax):
        new_Rs = [(self.mul, l) for l in range(lmax + 1)]
        if self.lmax == lmax:
            return self
        elif self.lmax > lmax:
            new_signal = self.signal[:rs.dim(new_Rs)]
            return FourierTensor(new_signal, self.mul, lmax)
        elif self.lmax < lmax:
            new_signal = torch.zeros(rs.dim(new_Rs))
            new_signal[:rs.dim(self.Rs)] = self.signal
            return FourierTensor(new_signal, self.mul, lmax)
Example 8
    def change_lmax(self, lmax):
        new_Rs = [(1, l) for l in range(lmax + 1)]
        if self.lmax == lmax:
            return self
        elif self.lmax > lmax:
            new_signal = self.signal[..., :rs.dim(new_Rs)]
            return SphericalTensor(new_signal, self.p_val, self.p_arg)
        elif self.lmax < lmax:
            new_signal = torch.zeros(*self.signal.shape[:-1], rs.dim(new_Rs))
            new_signal[..., :rs.dim(self.Rs)] = self.signal
            return SphericalTensor(new_signal, self.p_val, self.p_arg)
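Both `change_lmax` variants above rely on the coefficients being stored in order of increasing l, so truncation keeps the leading `rs.dim(new_Rs)` entries and extension zero-pads the tail. A small worked check of the sizes involved (plain Python using the same mul * (2l + 1) convention; not e3nn code):

# With mul = 1: lmax = 2 stores 1 + 3 + 5 = 9 coefficients and lmax = 5 stores 36,
# so going from lmax = 5 to lmax = 2 keeps the first 9 entries; the reverse pads with 27 zeros.
dims = [sum(2 * l + 1 for l in range(lmax + 1)) for lmax in (2, 5)]
assert dims == [9, 36]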
Example 9
    def forward(self, features):
        """
        :param features: tensor [..., channel]
        :return:         tensor [..., channel]
        """
        size = features.shape[:-1]
        features = features.reshape(-1, rs.dim(self.Rs_in))

        output = torch.einsum('ij,zj->zi', self.kernel(), features)

        return output.reshape(*size, rs.dim(self.Rs_out))
Example 10
def test_group_kernel():
    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(5, 0, 1), (4, 1, -1)]
    Rs_out = [(3, 0, 1), (5, 1, -1)]
    groups = 4
    gkernel = GroupKernel(Rs_in, Rs_out, kernel, groups)

    N = 7
    input = torch.randn(N, 3)
    output = gkernel(input)
    assert output.dim() == 4  # [N, g, cout, cin]
    assert tuple(output.shape) == (N, groups, rs.dim(Rs_out), rs.dim(Rs_in))
Example 11
def initialize_edges(x,
                     Rs_in,
                     pos,
                     edge_index_dict,
                     lmax,
                     self_edge=1.,
                     symmetric_edges=False):
    """Initialize edge features of DataEdgeNeighbors using node features and SphericalTensor.

    Args:
        x (torch.tensor shape [N, rs.dim(Rs_in)]): Node features.
        Rs_in (rs.TY_RS_STRICT): Representation list of input.
        pos (torch.tensor shape [N, 3]): Cartesian coordinates of nodes.
        edge_index_dict (dict): Maps (target, source) node index pairs (node target index first, then node source) to edge indices.
        lmax (int > 0): Maximum L to use for the SphericalTensor projection of radial distance vectors.
        self_edge (float, optional): L=0 feature for self edges. Defaults to 1.
        symmetric_edges (bool, optional): Constrain edge features to be symmetric in node index. Defaults to False.

    Returns:
        edge_x: Edge features.
        Rs_edge (rs.TY_RS_STRICT): Representation list of edge features.
    """
    from e3nn.tensor import SphericalTensor
    edge_x = []
    if symmetric_edges:
        Rs, Q = rs.reduce_tensor('ij=ji', i=Rs_in)
    else:
        Rs, Q = rs.reduce_tensor('ij', i=Rs_in, j=Rs_in)
    Q = Q.reshape(-1, rs.dim(Rs_in), rs.dim(Rs_in))
    Rs_sph = [(1, l, (-1)**l) for l in range(lmax + 1)]
    tp_kernel = rs.TensorProduct(Rs, Rs_sph, o3.selection_rule)
    keys, values = list(zip(*edge_index_dict.items()))
    sorted_edges = sorted(zip(keys, values), key=lambda x: x[1])
    for (target, source), _ in sorted_edges:
        Ia = x[target]
        Ib = x[source]
        vector = (pos[source] - pos[target]).reshape(-1, 3)
        if torch.allclose(vector, torch.zeros(vector.shape)):
            signal = torch.zeros(rs.dim(Rs_sph))
            signal[0] = self_edge
        else:
            signal = SphericalTensor.from_geometry(vector, lmax=lmax).signal
            if symmetric_edges:
                signal += SphericalTensor.from_geometry(-vector,
                                                        lmax=lmax).signal
                signal *= 0.5
        output = torch.einsum('kij,i,j->k', Q, Ia, Ib)
        output = tp_kernel(output, signal)
        edge_x.append(output)
    edge_x = torch.stack(edge_x, dim=0)
    return edge_x, tp_kernel.Rs_out
Example 12
    def forward(self, x):
        for lin in self.layers:
            x = lin(x)

            x = x.reshape(*x.shape[:-1], self.mul,
                          -1)  # put multiplicity into batch
            x1 = x.narrow(-1, 0, rs.dim(self.act1.Rs_in))
            x2 = x.narrow(-1, rs.dim(self.act1.Rs_in), rs.dim(self.act2.Rs_in))
            x1 = self.act1(x1)
            x2 = self.act2(x2)
            x = torch.cat([x1, x2], dim=-1)
            x = x.reshape(*x.shape[:-2], -1)  # put back into representation

        x = self.tail(x)
        return x
Example 13
    def forward(self, features_1, features_2):
        """
        :return:         tensor [..., channel]
        """
        *size, n = features_1.size()
        features_1 = features_1.reshape(-1, n)
        assert n == rs.dim(self.Rs_in1)
        *size2, n = features_2.size()
        features_2 = features_2.reshape(-1, n)
        assert size == size2

        T = get_sparse_buffer(self, 'T')  # [out, in1 * in2]
        kernel = (T.t() @ self.kernel().T).T.reshape(rs.dim(self.Rs_out), rs.dim(self.Rs_in1), rs.dim(self.Rs_in2))  # [out, in1, in2]
        features = torch.einsum('kij,zi,zj->zk', kernel, features_1, features_2)
        return features.reshape(*size, -1)
Example 14
def test_tensor_product_symmetry():
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (2, 1), (5, 2)]
        Rs_out = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        mul1 = rs.TensorProduct(Rs_in, o3.selection_rule, Rs_out)
        mul2 = rs.TensorProduct(o3.selection_rule, Rs_in, Rs_out)

        assert mul1.Rs_in2 == mul2.Rs_in1

        x = torch.randn(rs.dim(Rs_in), rs.dim(mul1.Rs_in2))
        y1 = mul1(x)
        y2 = mul2(x.T)

        assert (y1 - y2).abs().max() < 1e-10
Example 15
    def forward(self, r, r_eps=0, **_kwargs):
        """
        :param r: tensor [..., 3]
        :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
        """
        *size, xyz = r.size()
        assert xyz == 3
        r = r.reshape(-1, 3)

        radii = r.norm(2, dim=1)  # [batch]

        # (1) Case r > 0

        # precompute all needed spherical harmonics
        Y = rsh.spherical_harmonics_xyz(
            self.Ls, r[radii > r_eps])  # [batch, l_filter * m_filter]

        # Normalize the spherical harmonics
        if self.normalization == 'component':
            Y.mul_(math.sqrt(4 * math.pi))
        if self.normalization == 'norm':
            diag = math.sqrt(4 * math.pi) * torch.cat([
                torch.ones(2 * l + 1) / math.sqrt(2 * l + 1)
                for _, l, _ in self.Rs_f
            ])
            Y.mul_(diag)

        # use the radial model to fix all the degrees of freedom
        # note: for the normalization we assume that the variance of R[i] is one
        R = self.R(radii[radii > r_eps]
                   )  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        RY = rsh.mul_radial_angular(self.Rs_f, R, Y)

        if Y.shape[0] == 0:
            kernel1 = torch.zeros(0, rs.dim(self.Rs_out), rs.dim(self.Rs_in))
        else:
            kernel1 = self.tp.right(RY)

        # (2) Case r = 0

        kernel2 = self.linear()

        kernel = r.new_zeros(len(r), *kernel2.shape)
        kernel[radii > r_eps] = kernel1
        kernel[radii <= r_eps] = kernel2

        return kernel.reshape(*size, *kernel2.shape)
Example 16
    def rotation_gated_block(self, K):
        """Test rotation equivariance on GatedBlock and dependencies."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            K = partial(K, RadialModel=ConstantRadialModel)

            act = GatedBlock(Rs_out,
                             scalar_activation=sigmoid,
                             gate_activation=sigmoid)
            conv = Convolution(K(Rs_in, act.Rs_in))

            abc = torch.randn(3)
            rot_geo = o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            fea = torch.randn(1, 4, rs.dim(Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example 17
    def parity_rotation_gated_block_parity(self, K):
        """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

            K = partial(K, RadialModel=ConstantRadialModel)

            scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu),
                                                     (mul, absolute)]
            rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n = 3 * mul
            gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

            act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
            conv = Convolution(K(Rs_in, act.Rs_in))

            abc = torch.randn(3)
            rot_geo = -o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc, 1)
            D_out = rs.rep(act.Rs_out, *abc, 1)

            fea = torch.randn(1, 4, rs.dim(Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example 18
    def forward(self, features):
        """
        :param features: tensor [..., channel]
        :return:         tensor [..., channel]
        """
        *size, d = features.shape
        assert d == rs.dim(self.Rs)

        norms = self.norm(features)  # [..., l*mul]

        output = []
        index_features = 0
        index_norms = 0
        for mul, l, _ in self.Rs:
            v = features.narrow(-1, index_features, mul * (2 * l + 1)).reshape(*size, mul, 2 * l + 1)  # [..., u, i]
            index_features += mul * (2 * l + 1)

            n = norms.narrow(-1, index_norms, mul).reshape(*size, mul, 1)  # [..., u, i]
            b = self.bias[index_norms: index_norms + mul].reshape(mul, 1)  # [u, i]
            index_norms += mul

            if l == 0:
                out = self.activation(v + b)
            else:
                out = self.activation(n + b) * v

            output.append(out.reshape(*size, mul * (2 * l + 1)))

        return torch.cat(output, dim=-1)
Example 19
    def __init__(self, tensor, Rs):
        Rs = rs.convention(Rs)
        if tensor.shape[-1] != rs.dim(Rs):
            raise ValueError(
                "Last tensor dimension and Rs do not have same dimension.")
        self.tensor = tensor
        self.Rs = Rs
Example 20
def plot_data_on_grid(box_length,
                      radial,
                      Rs,
                      sh=o3.spherical_harmonics_xyz,
                      n=30):
    L_to_index = {}
    set_of_L = set([L for mul, L in Rs])
    start = 0
    for L in set_of_L:
        L_to_index[L] = [start, start + 2 * L + 1]
        start += 2 * L + 1

    # the complex step `n * 1j` makes np.mgrid return n points per axis, endpoints included
    r = np.mgrid[-1:1:n * 1j, -1:1:n * 1j, -1:1:n * 1j].reshape(3, -1)
    r = r.transpose(1, 0)
    r *= box_length / 2.
    r = torch.from_numpy(r)
    Ys = sh(set_of_L, r)
    R = radial(r.norm(2, -1)).detach()  # [r_values, n_r_filters]
    assert R.shape[-1] == rs.mul_dim(Rs)

    Ys_indices = []
    for mul, L in Rs:
        Ys_indices += list(range(L_to_index[L][0], L_to_index[L][1])) * mul

    R_helper = rs.map_mul_to_Rs(Rs)

    full_Ys = Ys[Ys_indices]  # [rs.dim(Rs), values]
    full_Ys = full_Ys.reshape(full_Ys.shape[0], -1)

    all_f = torch.einsum('xn,dn,dx->xd', R, R_helper, full_Ys)
    return r, all_f
Example 21
def main():
    representations1 = [(1, ), (3, 4, 0, 0), (8, 8, 0, 0), (8, 6, 0, 0),
                        (64, )]
    representations1 = [[(mul, l) for l, mul in enumerate(rs)]
                        for rs in representations1]
    representations2 = [(1, ), (2, 3, 2, 0), (6, 5, 5, 0), (6, 4, 4, 0),
                        (64, )]
    representations2 = [[(mul, l) for l, mul in enumerate(rs)]
                        for rs in representations2]
    representations3 = [(1, ), (2, 2, 2, 1), (4, 4, 4, 4), (6, 4, 4, 0),
                        (64, )]
    representations3 = [[(mul, l) for l, mul in enumerate(rs)]
                        for rs in representations3]
    representations0 = [[
        (mul, 0)
    ] for l, mul in enumerate([dim(r) for r in representations3])]

    tetris, labels = get_dataset()
    data = []
    for i, reps in enumerate([
            representations0, representations1, representations2,
            representations3
    ]):
        f = SE3Net(len(tetris), reps)
        training, _ = train(tetris, labels, f)
        data.append(training)
    return data
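The comprehensions above convert a tuple of multiplicities indexed by l into a representation list of (mul, l) pairs; a worked illustration of the same expression (plain Python, values taken from representations1 above):

# (3, 4, 0, 0) means 3 copies of l=0, 4 copies of l=1, and no l=2 or l=3 features.
assert [(mul, l) for l, mul in enumerate((3, 4, 0, 0))] == [(3, 0), (4, 1), (0, 2), (0, 3)]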
Example 22
def check_rotation(batch: int = 10, n_atoms: int = 25):
    # Setup the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    act = GatedBlock(
        Rs_out,
        scalar_activation=sigmoid,
        gate_activation=absolute,
    )
    conv = Convolution(K, Rs_in, act.Rs_in)

    # Setup the data. The geometry, input features, and output features must all rotate.
    abc = torch.randn(3)  # Rotation seed of euler angles.
    rot_geo = o3.rot(*abc)
    D_in = rs.rep(Rs_in, *abc)
    D_out = rs.rep(Rs_out, *abc)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))  # Transforms with wigner D matrix
    geo = torch.randn(batch, n_atoms, 3)  # Transforms with rotation matrix.

    # Test equivariance.
    F = act(conv(feat, geo))
    RF = torch.einsum("ij,zkj->zki", D_out, F)
    FR = act(conv(feat @ D_in.t(), geo @ rot_geo.t()))
    return (RF - FR).norm() < 10e-5 * RF.norm()
Example 23
    def test1(self):
        torch.set_default_dtype(torch.float64)
        Rs_in = [(5, 0), (20, 1), (15, 0), (20, 2)]
        Rs_out = [(5, 0), (10, 1), (10, 2), (5, 0)]

        with torch.no_grad():
            lin = Linear(Rs_in, Rs_out)
            features = torch.randn(10000, rs.dim(Rs_in))
            features = lin(features)

        bins, left, right = 100, -4, 4
        bin_width = (right - left) / (bins - 1)
        x = torch.linspace(left, right, bins)
        p = torch.histc(features, bins, left,
                        right) / features.numel() / bin_width
        q = x.pow(2).div(-2).exp().div(math.sqrt(2 * math.pi))  # Normal law

        # import matplotlib.pyplot as plt
        # plt.plot(x, p)
        # plt.plot(x, q)
        # plt.show()

        Dkl = ((p + 1e-100) /
               q).log().mul(p).sum()  # Kullback-Leibler divergence of P || Q
        self.assertLess(Dkl, 0.1)
Example 24
def kernel_conv_fn_forward(F, Y, R, norm_coef, Rs_in, Rs_out, selection_rule,
                           set_of_l_filters):
    """
    :param F: tensor [batch, b, l_in * mul_in * m_in]
    :param Y: tensor [batch, a, b, l_filter * m_filter]
    :param R: tensor [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in]
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    batch, a, b, _ = Y.shape
    n_out = rs.dim(Rs_out)

    kernel_conv = Y.new_zeros(batch, a, n_out)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = selection_rule(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, :, :, begin_R:begin_R + n].reshape(
                batch, a, b, mul_out, mul_in,
                -1)  # [batch, a, b, mul_out, mul_in, l_filter]
            begin_R += n

            K = 0
            for k, l_filter in enumerate(l_filters):
                offset = sum(2 * l + 1 for l in set_of_l_filters
                             if l < l_filter)
                sub_Y = Y[...,
                          offset:offset + 2 * l_filter + 1]  # [batch, a, b, m]

                C = o3.wigner_3j(l_out,
                                 l_in,
                                 l_filter,
                                 cached=True,
                                 like=kernel_conv)  # [m_out, m_in, m]

                K += norm_coef[i, j] * torch.einsum(
                    "ijk,zabk,zabuv,zbvj->zaui", C,
                    sub_Y, sub_R[..., k], F[..., s_in].reshape(
                        batch, b, mul_in, -1))  # [batch, a, mul_out, m_out]

            if not isinstance(K, int):
                kernel_conv[:, :, s_out] += K.reshape(batch, a, -1)

    return kernel_conv
Example 25
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    # Setup the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0, +1)]
    act = GatedBlockParity(
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)]
    )
    conv = Convolution(K, Rs_in, act.Rs_in)
    Rs_out = act.Rs_out

    # Setup the data. The geometry, input features, and output features must all rotate and observe parity.
    abc = torch.randn(3)  # Rotation seed of euler angles.
    rot_geo = -o3.rot(*abc)  # Negative because geometry has odd parity. i.e. improper rotation.
    D_in = rs.rep(Rs_in, *abc, parity=1)
    D_out = rs.rep(Rs_out, *abc, parity=1)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))  # Transforms with wigner D matrix and parity.
    geo = torch.randn(batch, n_atoms, 3)  # Transforms with rotation matrix and parity.

    # Test equivariance.
    F = act(conv(feat, geo))
    RF = torch.einsum("ij,zkj->zki", D_out, F)
    FR = act(conv(feat @ D_in.t(), geo @ rot_geo.t()))
    return (RF - FR).norm() < 10e-5 * RF.norm()
Example 26
    def forward(self, features):
        '''
        :param features: [..., channels]
        '''
        *size, n = features.size()
        features = features.reshape(-1, n)
        assert n == rs.dim(self.Rs_in)

        if self.linear:
            features = torch.cat([features.new_ones(features.shape[0], 1), features], dim=1)
            n += 1

        T = get_sparse_buffer(self, 'T')  # [out, in1 * in2]
        kernel = (T.t() @ self.kernel().T).T.reshape(rs.dim(self.Rs_out), n, n)  # [out, in1, in2]
        features = torch.einsum('zi,zj->zij', features, features)
        features = torch.einsum('kij,zij->zk', kernel, features)
        return features.reshape(*size, -1)
Example 27
    def test_tensor_product(self):
        torch.set_default_dtype(torch.float64)

        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        Rs_out, m = rs.tensor_product(Rs_1, Rs_2, o3.selection_rule)
        mul = TensorProduct(Rs_1, Rs_2)

        x1 = torch.randn(1, rs.dim(Rs_1))
        x2 = torch.randn(1, rs.dim(Rs_2))

        y1 = mul(x1, x2)
        y2 = torch.einsum('kij,zi,zj->zk', m, x1, x2)

        self.assertEqual(rs.dim(Rs_out), y1.shape[1])
        self.assertLess((y1 - y2).abs().max(), 1e-7 * y1.abs().max())
Example 28
    def forward(self, features_1, features_2, weights):
        """
        :return:         tensor [..., channel]
        """
        *size, n = features_1.size()
        features_1 = features_1.reshape(-1, n)
        assert n == rs.dim(self.Rs_in1), f"{n} is not {rs.dim(self.Rs_in1)}"
        *size2, n = features_2.size()
        features_2 = features_2.reshape(-1, n)
        assert n == rs.dim(self.Rs_in2), f"{n} is not {rs.dim(self.Rs_in2)}"
        assert size == size2
        weights = weights.reshape(-1, self.nweight)

        wigners = [getattr(self, arg) for arg in self.wigners_names]

        features = self.main(*wigners, features_1, features_2, weights)
        return features.reshape(*size, -1)
Example 29
        def test(Rs, ac):
            x = torch.randn(99, rs.dim(Rs))
            a, b = torch.rand(2)
            c = 1

            y1 = ac(x, dim=-1) @ rs.rep(ac.Rs_out, a, b, c).T
            y2 = ac(x @ rs.rep(Rs, a, b, c).T, dim=-1)
            y3 = ac(x @ rs.rep(Rs, -c, -b, -a).T, dim=-1)
            self.assertLess((y1 - y2).norm(), (y1 - y3).norm())
Example 30
    def forward(self, x):
        """
        :param x: [batch, x, y, z, channel_in]
        :return: [batch, x, y, z, channel_out]
        """
        for conv, act, pool in self.layers:
            x = conv(x)

            x = x.reshape(*x.shape[:-1], self.mul,
                          rs.dim(act.Rs_in))  # put multiplicity into batch
            x = act(x)
            x = x.reshape(*x.shape[:-2], self.mul *
                          rs.dim(act.Rs_out))  # put back into representation

            x = pool(x)

        x = self.tail(x)
        return x