Example #1
def test_sh_parity():
    """
    (-1)^l Y(x) = Y(-x)
    """
    with o3.torch_default_dtype(torch.float64):
        for l in range(7 + 1):
            x = torch.randn(3)
            Y1 = (-1) ** l * rsh.spherical_harmonics_xyz([l], x)
            Y2 = rsh.spherical_harmonics_xyz([l], -x)
            assert (Y1 - Y2).abs().max() < 1e-10 * Y1.abs().max()
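Note: a minimal standalone sketch of the same parity law, evaluated for several l at once; it assumes the list-of-l calling convention and the e3nn imports used throughout these examples. Each Y_l is the restriction of a degree-l harmonic polynomial, hence the (-1)^l sign under x -> -x.

import torch
from e3nn import o3, rsh

with o3.torch_default_dtype(torch.float64):
    x = torch.randn(3)
    Y = rsh.spherical_harmonics_xyz([0, 1, 2], x)        # [1 + 3 + 5 components]
    Y_flip = rsh.spherical_harmonics_xyz([0, 1, 2], -x)
    signs = torch.cat([(-1.0) ** l * torch.ones(2 * l + 1) for l in range(3)])
    assert torch.allclose(Y_flip, signs * Y)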
Example #2
def test_sh_cuda_single():
    if torch.cuda.is_available():
        with o3.torch_default_dtype(torch.float64):
            for l in range(10 + 1):
                x = torch.randn(10, 3)
                x_cuda = x.cuda()
                Y1 = rsh.spherical_harmonics_xyz([l], x)
                Y2 = rsh.spherical_harmonics_xyz([l], x_cuda).cpu()
                assert (Y1 - Y2).abs().max() < 1e-7
    else:
        print("Cuda is not available! test_sh_cuda_single skipped!")
Example #3
def test_sh_cuda_ordered_partial():
    if torch.cuda.is_available():
        with o3.torch_default_dtype(torch.float64):
            l = [0, 2, 5, 7, 10]
            x = torch.randn(10, 3)
            x_cuda = x.cuda()
            Y1 = rsh.spherical_harmonics_xyz(l, x)
            Y2 = rsh.spherical_harmonics_xyz(l, x_cuda).cpu()
            assert (Y1 - Y2).abs().max() < 1e-7
    else:
        print("Cuda is not available! test_sh_cuda_ordered_partial skipped!")
Example #4
File: o3_test.py Project: zizai/e3nn
    def test_sh_cuda_ordered_full(self):
        if torch.cuda.is_available():
            with o3.torch_default_dtype(torch.float64):
                l = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
                x = torch.randn(10, 3)
                x_cuda = x.cuda()
                Y1 = rsh.spherical_harmonics_xyz(l, x)
                Y2 = rsh.spherical_harmonics_xyz(l, x_cuda).cpu()
                self.assertLess((Y1 - Y2).abs().max(), 1e-7)
        else:
            print("Cuda is not available! test_sh_cuda_ordered_full skipped!")
Example #5
def test_sh_norm():
    with o3.torch_default_dtype(torch.float64):
        l_filter = list(range(15))
        Ys = [rsh.spherical_harmonics_xyz([l], torch.randn(10, 3)) for l in l_filter]
        s = torch.stack([Y.pow(2).mean(-1) for Y in Ys])
        d = s - 1 / (4 * math.pi)
        assert d.pow(2).mean().sqrt() < 1e-10

        n = rsh.spherical_harmonics_xyz(3, torch.randn(3), 'norm').norm()
        assert abs(n - 1) < 1e-10

        n = rsh.spherical_harmonics_xyz(3, torch.randn(3), 'component').norm()
        assert abs(n - 7**0.5) < 1e-10
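The two normalization conventions behind the last two asserts: 'norm' scales each degree-l block to unit L2 norm, while 'component' gives every component unit variance, so the block norm is sqrt(2l + 1) (hence 7**0.5 for l = 3). A hedged sketch of both, assuming the same API:

import torch
from e3nn import o3, rsh

with o3.torch_default_dtype(torch.float64):
    x = torch.randn(3)
    for l in range(4):
        n = rsh.spherical_harmonics_xyz(l, x, 'norm').norm()
        c = rsh.spherical_harmonics_xyz(l, x, 'component').norm()
        print(l, n.item(), c.item())  # ≈ 1 and ≈ sqrt(2l + 1)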
Example #6
File: rsh_test.py Project: mitwood/e3nn
def test_rsh_backwardable():
    lmax = 10
    Rs = [(1, l) for l in range(lmax + 1)]

    xyz = torch.tensor([0., 0., 1.], requires_grad=True)
    sph = rsh.spherical_harmonics_xyz(Rs, xyz, eps=0)
    sph.norm(2, -1).mean().backward()
    assert torch.allclose(
        torch.isnan(xyz.grad).nonzero(), torch.LongTensor([[0], [1], [2]]))

    xyz = torch.tensor([0., 0., 1.], requires_grad=True)
    sph = rsh.spherical_harmonics_xyz(Rs, xyz, eps=1e-10)
    sph.norm(2, -1).mean().backward()
    assert torch.allclose(
        torch.isnan(xyz.grad).nonzero(), torch.LongTensor([[]]))
Example #7
def spherical_harmonics_dirac(vectors, lmax):
    """
    approximation of a signal that is 0 everywhere except on the angle (alpha, beta) where it is one.
    the higher is lmax the better is the approximation
    """
    return 4 * math.pi / (lmax + 1)**2 * rsh.spherical_harmonics_xyz(
        list(range(lmax + 1)), vectors)
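A usage sketch for the function above: evaluating the coefficients back at the peak direction gives 1, because sum_m Y_lm(v)^2 = (2l + 1)/(4 pi) for every direction v and sum_{l<=lmax} (2l + 1) = (lmax + 1)^2. Assumes the same rsh module and that spherical_harmonics_dirac from above is in scope:

import torch
from e3nn import rsh

lmax = 6
v = torch.tensor([0.0, 0.0, 1.0])
coeff = spherical_harmonics_dirac(v, lmax)                 # [(lmax + 1)**2]
Y = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), v)
print(torch.dot(coeff, Y).item())  # ≈ 1 at the peak direction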
Example #8
    def forward(self, r, r_eps=0, custom_backward=False):
        """
        :param r: tensor [..., 3]
        :param custom_backward: call KernelFn rather than using automatic differentiation
        :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
        """
        *size, xyz = r.size()
        assert xyz == 3
        r = r.reshape(-1, 3)

        radii = r.norm(2, dim=1)  # [batch]

        # (1) Case r > 0

        # precompute all needed spherical harmonics
        Y = rsh.spherical_harmonics_xyz(self.set_of_l_filters, r[radii > r_eps])  # [batch, l_filter * m_filter]

        # use the radial model to fix all the degrees of freedom
        # note: for the normalization we assume that the variance of R[i] is one
        R = self.R(radii[radii > r_eps])  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        if custom_backward:
            kernel1 = KernelFn.apply(Y, R, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters)
        else:
            kernel1 = kernel_fn_forward(Y, R, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters)

        # (2) Case r = 0

        kernel2 = self.linear()

        kernel = r.new_zeros(len(r), *kernel2.shape)
        kernel[radii > r_eps] = kernel1
        kernel[radii <= r_eps] = kernel2

        return kernel.reshape(*size, *kernel2.shape)
Example #9
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 RadialModel,
                 r,
                 r_eps=0,
                 selection_rule=o3.selection_rule_in_out_sh,
                 normalization='component'):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param tensor r: [..., 3]
        :param float r_eps: distance considered as zero
        :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
        :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
        :param normalization: either 'norm' or 'component'
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        self.check_input_output(selection_rule)

        *self.size, xyz = r.size()
        assert xyz == 3
        r = r.reshape(-1, 3)  # [batch, space]
        self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
        self.r_eps = r_eps

        self.tp = rs.TensorProduct(self.Rs_in,
                                   selection_rule,
                                   self.Rs_out,
                                   normalization,
                                   sorted=True)
        self.Rs_f = self.tp.Rs_in2

        Y = rsh.spherical_harmonics_xyz(
            [(1, l, p) for _, l, p in self.Rs_f],
            r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]

        # Normalize the spherical harmonics
        if normalization == 'component':
            Y.mul_(math.sqrt(4 * math.pi))
        if normalization == 'norm':
            diag = math.sqrt(4 * math.pi) * torch.cat([
                torch.ones(2 * l + 1) / math.sqrt(2 * l + 1)
                for _, l, _ in self.Rs_f
            ])
            Y.mul_(diag)

        self.register_buffer('Y', Y)
        self.R = RadialModel(rs.mul_dim(self.Rs_f))

        if (self.radii <= self.r_eps).any():
            self.linear = KernelLinear(self.Rs_in, self.Rs_out)
        else:
            self.linear = None
Example #10
    def forward(self,
                features,
                edge_index,
                edge_r,
                sh=None,
                size=None,
                n_norm=1):
        """
        :param features: Tensor of shape [n_target, dim(Rs_in)]
        :param edge_index: LongTensor of shape [2, num_messages]
                           edge_index[0] = sources (convolution centers)
                           edge_index[1] = targets (neighbors)
        :param edge_r: Tensor of shape [num_messages, 3]
                       edge_r = position_target - position_source
        :param sh: Tensor of shape [num_messages, dim(Rs_sh)]
        :param size: (n_target, n_source) or None
        :param n_norm: typical number of targets per source

        :return: Tensor of shape [n_source, dim(Rs_out)]
        """
        if sh is None:
            sh = rsh.spherical_harmonics_xyz(
                self.Rs_sh, edge_r,
                self.normalization)  # [num_messages, dim(Rs_sh)]
        sh = sh / n_norm**0.5

        w = self.rm(edge_r.norm(dim=1))  # [num_messages, nweight]

        return self.propagate(edge_index, size=size, x=features, sh=sh, w=w)
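Note on the sh / n_norm**0.5 rescaling: if each convolution center aggregates roughly n_norm unit-variance messages, their sum has variance about n_norm, so dividing every message by sqrt(n_norm) keeps the output variance near one.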
Example #11
    def test1(self):
        """test gradients of the Kernel"""
        torch.set_default_dtype(torch.float64)
        Rs_in = [(1, 0), (1, 1), (1, 0), (1, 2)]
        Rs_out = [(1, 0), (1, 1), (1, 2), (1, 0)]
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel,
                        partial(o3.selection_rule_in_out_sh, lmax=1))

        n_path = 0
        for mul_out, l_out, p_out in kernel.Rs_out:
            for mul_in, l_in, p_in in kernel.Rs_in:
                l_filters = kernel.selection_rule(l_in, p_in, l_out, p_out)
                n_path += mul_out * mul_in * len(l_filters)

        r = torch.randn(2, 3)
        Y = rsh.spherical_harmonics_xyz(kernel.set_of_l_filters,
                                        r)  # [batch, l_filter * m_filter]
        Y = Y.clone().detach().requires_grad_(True)
        R = torch.randn(
            2, n_path, requires_grad=True
        )  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        inputs = (Y, R, kernel.norm_coef, kernel.Rs_in, kernel.Rs_out,
                  kernel.selection_rule, kernel.set_of_l_filters)
        self.assertTrue(torch.autograd.gradcheck(KernelFn.apply, inputs))
Example #12
    def signal_xyz(self, r):
        """
        Evaluate the signal on the sphere
        """
        sh = rsh.spherical_harmonics_xyz(list(range(self.lmax + 1)), r)
        dim = (self.lmax + 1)**2
        output = torch.einsum('ai,zi->za', sh.reshape(-1, dim),
                              self.signal.reshape(-1, dim))
        return output.reshape((*self.signal.shape[:-1], *r.shape[:-1]))
Example #13
def test_wigner_3j_sh_norm():
    with o3.torch_default_dtype(torch.float64):
        for l_out in range(3 + 1):
            for l_in in range(l_out, 4 + 1):
                for l_f in range(abs(l_out - l_in), l_out + l_in + 1):
                    Q = o3.wigner_3j(l_out, l_in, l_f)
                    Y = rsh.spherical_harmonics_xyz([l_f], torch.randn(3))
                    QY = math.sqrt(4 * math.pi) * Q @ Y
                    assert abs(QY.norm() - 1) < 1e-10
Example #14
File: o3_test.py Project: zizai/e3nn
    def test_sh_norm(self):
        with o3.torch_default_dtype(torch.float64):
            l_filter = list(range(15))
            Ys = [
                rsh.spherical_harmonics_xyz([l], torch.randn(10, 3))
                for l in l_filter
            ]
            s = torch.stack([Y.pow(2).mean(-1) for Y in Ys])
            d = s - 1 / (4 * math.pi)
            self.assertLess(d.pow(2).mean().sqrt(), 1e-10)
Example #15
def forward(f, shapes, labels, lmax, device):
    r_max = 1.1
    x = torch.ones(4, 1)
    batch = Batch.from_data_list([DataNeighbors(x, shape, r_max, y=label, self_interaction=False) for shape, label in zip(shapes, labels)])
    batch = batch.to(device)
    sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), batch.edge_attr, 'component')
    out = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
    out = scatter_add(out, batch.batch, dim=0)
    out = torch.tanh(out)
    return out
Example #16
File: convolution.py Project: wudangt/e3nn
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 size,
                 steps=(1, 1, 1),
                 lmax=None,
                 fuzzy_pixels=False,
                 allow_unused_inputs=False,
                 allow_zero_outputs=False,
                 **kwargs):
        super().__init__()

        r = torch.linspace(-1, 1, size)
        x = r * steps[0] / min(steps)
        x = x[x.abs() <= 1]
        y = r * steps[1] / min(steps)
        y = y[y.abs() <= 1]
        z = r * steps[2] / min(steps)
        z = z[z.abs() <= 1]
        r = torch.stack(torch.meshgrid(x, y, z), dim=-1)  # [x, y, z, R^3]

        R = partial(CosineBasisModel,
                    max_radius=1.0,
                    number_of_basis=(size + 1) // 2,
                    h=50,
                    L=3,
                    act=swish)
        self.kernel = FrozenKernel(
            Rs_in,
            Rs_out,
            R,
            r,
            selection_rule=partial(o3.selection_rule_in_out_sh, lmax=lmax),
            normalization='component',
            allow_unused_inputs=allow_unused_inputs,
            allow_zero_outputs=allow_zero_outputs,
        )
        self.kwargs = kwargs

        if fuzzy_pixels:
            # re-evaluate the spherical harmonics on randomly jittered pixel positions and average them below
            r = r.reshape(-1, 3)
            r = r[self.kernel.radii > 0]
            rand = torch.rand(20**3, *r.shape).mul(2).sub(1)  # [-1, 1]
            rand.mul_(1 / (size - 1))
            rand[:, :, 0].mul_(steps[0] / min(steps))
            rand[:, :, 1].mul_(steps[1] / min(steps))
            rand[:, :, 2].mul_(steps[2] / min(steps))
            r = rand + r.unsqueeze(0)  # [rand, batch, R^3]
            Y = rsh.spherical_harmonics_xyz(
                [(1, l, p) for _, l, p in self.kernel.Rs_f],
                r)  # [rand, batch, l_filter * m_filter]
            Y.mul_(math.sqrt(4 * math.pi))  # normalization='component'
            self.kernel.Y.copy_(Y.mean(0))
Example #17
    def forward(self,
                features,
                difference_geometry,
                mask,
                y=None,
                radii=None,
                custom_backward=True):
        """
        :param features: tensor [batch, b, l_in * mul_in * m_in]
        :param difference_geometry: tensor [batch, a, b, xyz]
        :param mask:     tensor [batch, a] (In order to zero contributions from padded atoms.)
        :param y:        Optional precomputed spherical harmonics.
        :param radii:    Optional precomputed normed geometry.
        :param custom_backward: call KernelConvFn rather than using automatic differentiation, (default True)
        :return:         tensor [batch, a, l_out * mul_out * m_out]
        """
        _batch, _a, _b, xyz = difference_geometry.size()
        assert xyz == 3

        if radii is None:
            radii = difference_geometry.norm(2, dim=-1)  # [batch, a, b]

        # precompute all needed spherical harmonics
        if y is None:
            y = rsh.spherical_harmonics_xyz(
                self.set_of_l_filters,
                difference_geometry)  # [batch, a, b, l_filter * m_filter]

        y[radii == 0] = 0  # sh is ill-defined at r = 0; zero those contributions

        # use the radial model to fix all the degrees of freedom
        # note: for the normalization we assume that the variance of R[i] is one
        r = self.R(radii.flatten()).reshape(
            *radii.shape,
            -1)  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
        r = r.clone()
        r[radii == 0] = 0

        if custom_backward:
            output = KernelConvFn.apply(features, y, r, self.norm_coef,
                                        self.Rs_in, self.Rs_out,
                                        self.selection_rule,
                                        self.set_of_l_filters)
        else:
            output = kernel_conv_fn_forward(features, y, r, self.norm_coef,
                                            self.Rs_in, self.Rs_out,
                                            self.selection_rule,
                                            self.set_of_l_filters)

        # Case r = 0: add the self-interaction (applies when the geometry is square, i.e. a == b)
        if radii.shape[1] == radii.shape[2]:
            output += torch.einsum('ij,zaj->zai', self.linear(), features)

        return output * mask.unsqueeze(-1)
Example #18
def forward(f, shapes, Rs_sh, device):
    r_max = 1.1
    x = torch.ones(4, 1)
    batch = Batch.from_data_list([DataNeighbors(x, shape, r_max, self_interaction=False) for shape in shapes])
    batch = batch.to(device)
    # Pre-compute the spherical harmonics and re-use them in each convolution
    sh = rsh.spherical_harmonics_xyz(Rs_sh, batch.edge_attr, 'component')
    out = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
    out = scatter_add(out, batch.batch, dim=0)
    out = torch.tanh(out)
    return out
Example #19
    def test_irr_repr_wigner_3j(self):
        """Test irr_repr and wigner_3j equivariance."""
        with torch_default_dtype(torch.float64):
            l_in = 3
            l_out = 2

            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                r = torch.randn(100, 3)
                Q = o3.wigner_3j(l_out, l_in, l_f)

                abc = torch.randn(3)
                D_in = o3.irr_repr(l_in, *abc)
                D_out = o3.irr_repr(l_out, *abc)

                Y = rsh.spherical_harmonics_xyz([l_f], r @ o3.rot(*abc).t())
                W = torch.einsum("ijk,zk->zij", (Q, Y))
                W1 = torch.einsum("zij,jk->zik", (W, D_in))

                Y = rsh.spherical_harmonics_xyz([l_f], r)
                W = torch.einsum("ijk,zk->zij", (Q, Y))
                W2 = torch.einsum("ij,zjk->zik", (D_out, W))

                self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
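For reference, the identity this test checks, read off the einsums above: writing W(r)_ij = sum_k Q_ijk Y_k^(l_f)(r), equivariance under a rotation g with Wigner matrices D^(l) reads D^(l_out)(g) W(r) = W(g r) D^(l_in)(g).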
Example #20
def adjusted_projection(vectors, lmax):
    """
    :param vectors: tensor of shape [..., xyz]
    :return: tensor of shape [l * m]
    """
    vectors = vectors.reshape(-1, 3)
    radii = vectors.norm(2, -1)  # [batch]
    vectors = vectors[radii > 0]  # [batch, 3]

    coeff = projection(vectors, lmax)  # [batch, l * m]
    A = torch.einsum(
        "ai,bi->ab", rsh.spherical_harmonics_xyz(list(range(lmax + 1)),
                                                 vectors), coeff)
    coeff *= torch.lstsq(radii, A).solution.reshape(-1).unsqueeze(-1)
    return coeff.sum(0)
Example #21
    def forward(self, r, r_eps=0, **_kwargs):
        """
        :param r: tensor [..., 3]
        :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
        """
        *size, xyz = r.size()
        assert xyz == 3
        r = r.reshape(-1, 3)

        radii = r.norm(2, dim=1)  # [batch]

        # (1) Case r > 0

        # precompute all needed spherical harmonics
        Y = rsh.spherical_harmonics_xyz(
            self.Ls, r[radii > r_eps])  # [batch, l_filter * m_filter]

        # Normalize the spherical harmonics
        if self.normalization == 'component':
            Y.mul_(math.sqrt(4 * math.pi))
        if self.normalization == 'norm':
            diag = math.sqrt(4 * math.pi) * torch.cat([
                torch.ones(2 * l + 1) / math.sqrt(2 * l + 1)
                for _, l, _ in self.Rs_f
            ])
            Y.mul_(diag)

        # use the radial model to fix all the degrees of freedom
        # note: for the normalization we assume that the variance of R[i] is one
        R = self.R(radii[radii > r_eps]
                   )  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        RY = rsh.mul_radial_angular(self.Rs_f, R, Y)

        if Y.shape[0] == 0:
            kernel1 = torch.zeros(0, rs.dim(self.Rs_out), rs.dim(self.Rs_in))
        else:
            kernel1 = self.tp.right(RY)

        # (2) Case r = 0

        kernel2 = self.linear()

        kernel = r.new_zeros(len(r), *kernel2.shape)
        kernel[radii > r_eps] = kernel1
        kernel[radii <= r_eps] = kernel2

        return kernel.reshape(*size, *kernel2.shape)
Example #22
    def forward(self, features, edge_index, edge_r, sh=None, size=None, n_norm=1):
        # features = [num_atoms, dim(Rs_in)]
        if sh is None:
            sh = rsh.spherical_harmonics_xyz(self.Rs_sh, edge_r, "component")  # [num_messages, dim(Rs_sh)]
        sh = sh / n_norm**0.5

        w = self.rm(edge_r.norm(dim=1))  # [num_messages, nweight]

        self_interaction = self.lin1(features)
        features = self.propagate(edge_index, size=size, x=features, sh=sh, w=w)
        features = self.lin2(features)
        has_self_interaction = torch.cat([
            torch.ones(mul * (2 * l + 1)) if any(l_in == l and p_in == p for _, l_in, p_in in self.Rs_in) else torch.zeros(mul * (2 * l + 1))
            for mul, l, p in self.Rs_out
        ])
        return 0.5**0.5 * self_interaction + (1 + (0.5**0.5 - 1) * has_self_interaction) * features
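The 0.5**0.5 factors in the return line implement a variance-preserving skip connection: output channels that have a matching self-interaction path mix two roughly independent unit-variance signals as (u + v)/sqrt(2), which keeps the variance near one. A self-contained check of that arithmetic:

import torch

u, v = torch.randn(2, 100000)
mixed = 0.5**0.5 * u + 0.5**0.5 * v
print(mixed.var().item())  # ≈ 1 for independent unit-variance inputs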
Example #23
def test_sh_closure():
    """
    integral of Ylm * Yjn = delta_lj delta_mn
    integral of 1 over the unit sphere = 4 pi
    """
    with o3.torch_default_dtype(torch.float64):
        x = torch.randn(300000, 3)
        Ys = [rsh.spherical_harmonics_xyz([l], x) for l in range(0, 3 + 1)]
        for l1, Y1 in enumerate(Ys):
            for l2, Y2 in enumerate(Ys):
                m = (Y1.reshape(-1, 2 * l1 + 1, 1) * Y2.reshape(-1, 1, 2 * l2 + 1)).mean(0) * 4 * math.pi
                if l1 == l2:
                    i = torch.eye(2 * l1 + 1)
                    assert (m - i).abs().max() < 0.01
                else:
                    assert m.abs().max() < 0.01
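A worked special case of the closure test: under the normalization checked in test_sh_norm, Y_00 has constant square 1/(4 pi), so the l1 = l2 = 0 entry is exactly 4 pi * Y_00^2 = 1. A minimal sketch, assuming the same API:

import math
import torch
from e3nn import o3, rsh

with o3.torch_default_dtype(torch.float64):
    Y0 = rsh.spherical_harmonics_xyz([0], torch.randn(10, 3))
    print((4 * math.pi * Y0.pow(2).mean()).item())  # = 1 up to rounding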
Example #24
def adjusted_projection(vectors, lmax):
    """
    :param vectors: tensor of shape [..., xyz]
    :return: tensor of shape [l * m]
    """
    vectors = vectors.reshape(-1, 3)
    radii = vectors.norm(2, -1)  # [batch]
    vectors = vectors[radii > 0]  # [batch, 3]

    coeff = rsh.spherical_harmonics_xyz(list(range(lmax + 1)),
                                        vectors)  # [batch, l * m]
    A = torch.einsum("ai,bi->ab", coeff, coeff)
    # Y(v_a) . Y(v_b) solution_b = radii_a
    solution = torch.lstsq(radii, A).solution.reshape(-1)  # [b]
    assert (radii - A @ solution).abs().max() < 1e-5 * radii.abs().max()

    return solution @ coeff
Example #25
    def from_geometry(cls, vectors, lmax, p=0, adjusted=True):
        """
        :param vectors: tensor of vectors (p=-1) or pseudovectors (p=1) of shape [..., 3=xyz]
        """
        if adjusted:
            signal = adjusted_projection(vectors, lmax)
        else:
            vectors = vectors.reshape(-1, 3)
            r = vectors.norm(dim=1)
            sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), vectors)
            # 0.5 * sum_a ( Y(v_a) . sum_b r_b Y(v_b) s - r_a )^2
            A = torch.einsum('ai,b,bi->a', sh, r, sh)
            # 0.5 * sum_a ( A_a s - r_a )^2
            # sum_a A_a^2 s = sum_a A_a r_a
            s = torch.dot(A, r) / A.norm().pow(2)
            signal = s * torch.einsum('a,ai->i', r, sh)
        return cls(signal, p_val=1, p_arg=p)
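The closed-form step in the comments above, checked in isolation: s = <A, r> / ||A||^2 minimizes 0.5 * ||A s - r||^2, so the residual A s - r is orthogonal to A. A small self-contained sketch:

import torch

A = torch.randn(10)
r = torch.randn(10)
s = torch.dot(A, r) / A.norm().pow(2)
assert torch.dot(A, A * s - r).abs() < 1e-5  # residual is orthogonal to A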
Example #26
    def forward(self, features, edge_index, edge_r, size=None, n_norm=1, custom_backward=False):
        """
        :param features: Tensor of shape [n_target, dim(Rs_in)]
        :param edge_index: LongTensor of shape [2, num_edges] ~ [a, b]
                           edge_index[0] = sources (convolution centers)
                           edge_index[1] = targets (neighbors)
        :param edge_r: Tensor of shape [num_edges, 3]
                       edge_r = position_target - position_source
        :param size: n_points or None
        :param n_norm: typical number of targets per source

        :return: Tensor of shape [n_points, dim(Rs_out)]
        """
        assert edge_r.shape[1] == 3

        radii = edge_r.norm(2, dim=-1)

        # precompute all needed spherical harmonics
        y = rsh.spherical_harmonics_xyz(self.set_of_l_filters, edge_r)  # [num_edges, l_filter * m_filter]

        y[radii == 0] = 0  # sh is ill-defined at r = 0; zero those edges

        # use the radial model to fix all the degrees of freedom
        # note: for the normalization we assume that the variance of R[i] is one
        r = self.R(radii.flatten()).reshape(*radii.shape, -1)  # [*_, n_edges, l_out * l_in * mul_out * mul_in * l_filter]
        r = r.clone()
        r[radii == 0] = 0

        if custom_backward:
            assert False, "Custom backward for sparse kernel: not coded yet!"
            #output = KernelConvFn.apply(
            #    features, edge_index, y, r, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters
            #)
        else:
            output = kernel_conv_fn_forward(
                features, edge_index, y, r, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters
            )
        
        output.div_(n_norm ** 0.5)

        # Case r = 0: add the self-interaction
        output += torch.einsum('ij,aj->ai', self.linear(), features)

        return output
Example #27
    def forward(self,
                features,
                edge_index,
                edge_r,
                sh=None,
                size=None,
                n_norm=1):
        if sh is None:
            sh = rsh.spherical_harmonics_xyz(
                self.Rs_sh, edge_r, "component")  # [num_messages, dim(Rs_sh)]
        sh = sh / n_norm**0.5

        w = self.rm(edge_r.norm(dim=1))  # [num_messages, nweight]

        features = self.propagate(edge_index,
                                  size=size,
                                  x=features,
                                  sh=sh,
                                  w=w)
        features = self.lin(features)
        return features
Example #28
    def test1(self):
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        torch.set_default_dtype(torch.float64)
        Rs_in = [(1, 0), (1, 1), (1, 2), (1, 0)]
        Rs_out = [(1, 0), (1, 1), (2, 0), (1, 2)]
        KC = KernelConv(Rs_in, Rs_out, ConstantRadialModel,
                        partial(o3.selection_rule_in_out_sh,
                                lmax=1)).to(device)

        n_path = 0
        for mul_out, l_out, p_out in KC.Rs_out:
            for mul_in, l_in, p_in in KC.Rs_in:
                l_filters = KC.selection_rule(l_in, p_in, l_out, p_out)
                n_path += mul_out * mul_in * len(l_filters)

        batch = 1
        atoms = 3

        F = torch.randn(batch, atoms, dim(Rs_in),
                        requires_grad=True).to(device)
        geo = torch.randn(batch, atoms, 3)
        r = (geo.unsqueeze(1) - geo.unsqueeze(2)).to(device)
        Y = rsh.spherical_harmonics_xyz(
            KC.set_of_l_filters, r)  # [batch, a, b, l_filter * m_filter]
        Y[r.norm(2, dim=-1) == 0] = 0
        Y = Y.clone().detach().requires_grad_(True).to(device)
        R = torch.randn(batch, atoms, atoms, n_path, requires_grad=True).to(
            device
        )  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]

        inputs = (F, Y, R, KC.norm_coef, KC.Rs_in, KC.Rs_out,
                  KC.selection_rule, KC.set_of_l_filters)
        self.assertTrue(torch.autograd.gradcheck(KernelConvFn.apply, inputs))
Example #29
def main():
    torch.set_default_dtype(torch.float64)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    x = torch.ones(4, 1)
    Rs_in = [(1, 0, 1)]
    r_max = 1.1

    tetris, labels = get_dataset()
    tetris_dataset = [
        dh.DataNeighbors(x, shape, r_max, y=label)
        for shape, label in zip(tetris, labels)
    ]

    Rs_out = [(1, 0, -1), (6, 0, 1)]
    lmax = 3

    f = MLNetwork(Rs_in, Rs_out, Convolution,
                  partial(make_gated_block, mul=16, lmax=lmax), 2)
    f = f.to(device)

    batch = Batch.from_data_list(tetris_dataset)
    batch = batch.to(device)
    sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), batch.edge_attr,
                                     'component')

    optimizer = torch.optim.Adam(f.parameters(), lr=3e-3)

    wall = time.perf_counter()
    for step in range(100):
        out = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
        out = scatter_add(out, batch.batch, dim=0)
        out = torch.tanh(out)

        acc = out.cpu().round().eq(labels).double().mean().item()

        r_tetris_dataset = [
            dh.DataNeighbors(x, shape, r_max, y=label)
            for shape, label in zip(*get_dataset())
        ]
        r_batch = Batch.from_data_list(r_tetris_dataset)
        r_batch = r_batch.to(device)
        r_sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)),
                                           r_batch.edge_attr, 'component')

        with torch.no_grad():
            r_out = f(r_batch.x,
                      r_batch.edge_index,
                      r_batch.edge_attr,
                      sh=r_sh,
                      n_norm=3)
            r_out = scatter_add(r_out, r_batch.batch, dim=0)
            r_out = torch.tanh(r_out)

        loss = (out - labels).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(
            "wall={:.1f} step={} loss={:.2e} accuracy={:.2f} equivariance error={:.1e}"
            .format(time.perf_counter() - wall, step, loss.item(), acc,
                    (out - r_out).pow(2).mean().sqrt().item()))

    print(labels.numpy().round(1))
    print(out.detach().numpy().round(1))