Ejemplo n.º 1
0
    def test1(self):
        """Gradient-check KernelFn.apply on random spherical harmonics and radial weights."""
        torch.set_default_dtype(torch.float64)
        Rs_in = [(1, 0), (1, 1), (1, 0), (1, 2)]
        Rs_out = [(1, 0), (1, 1), (1, 2), (1, 0)]
        selection_rule = partial(o3.selection_rule_in_out_sh, lmax=1)
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel, selection_rule)

        # Count radial paths: every (output, input, allowed filter) combination,
        # weighted by both multiplicities.
        n_path = sum(
            mul_out * mul_in * len(kernel.selection_rule(l_in, p_in, l_out, p_out))
            for mul_out, l_out, p_out in kernel.Rs_out
            for mul_in, l_in, p_in in kernel.Rs_in
        )

        positions = torch.randn(2, 3)
        # Spherical harmonics turned into a fresh leaf tensor that tracks gradients.
        # Original layout note: [l_filter * m_filter, batch]
        Y = rsh.spherical_harmonics_xyz(kernel.set_of_l_filters, positions)
        Y = Y.clone().detach().requires_grad_(True)
        # One radial weight per path for each of the 2 samples.
        # Original layout note: [batch, l_out * l_in * mul_out * mul_in * l_filter]
        R = torch.randn(2, n_path, requires_grad=True)

        inputs = (Y, R, kernel.norm_coef, kernel.Rs_in, kernel.Rs_out,
                  kernel.selection_rule, kernel.set_of_l_filters)
        self.assertTrue(torch.autograd.gradcheck(KernelFn.apply, inputs))
Ejemplo n.º 2
0
    def test1(self):
        """Gradient-check KernelFn.apply for several choices of which inputs require grad."""
        torch.set_default_dtype(torch.float64)
        Rs_in = [(1, 0), (1, 1), (2, 0), (1, 2)]
        Rs_out = [(2, 0), (1, 1), (1, 2), (3, 0)]
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)

        # Number of radial paths across all output/input pairs and their filters.
        n_path = sum(
            mul_out * mul_in * len(kernel.get_l_filters(l_in, p_in, l_out, p_out))
            for mul_out, l_out, p_out in kernel.Rs_out
            for mul_in, l_in, p_in in kernel.Rs_in
        )

        grad_flags = [(True, True), (True, False), (False, True)]
        for y_needs_grad, r_needs_grad in grad_flags:
            positions = torch.randn(2, 3)
            radii = positions.norm(2, dim=1)
            # Original layout note: [l_filter * m_filter, batch]
            Y = kernel.sh(kernel.set_of_l_filters, positions)
            Y = Y.clone().detach().requires_grad_(y_needs_grad)
            # Original layout note: [batch, l_out * l_in * mul_out * mul_in * l_filter]
            R = torch.randn(2, n_path, requires_grad=r_needs_grad)
            # Select index 1 of the last norm_coef axis for zero-length vectors, index 0 otherwise.
            at_origin = (radii == 0).type(torch.long)
            norm_coef = kernel.norm_coef[:, :, at_origin]

            inputs = (Y, R, norm_coef, kernel.Rs_in, kernel.Rs_out,
                      kernel.get_l_filters, kernel.set_of_l_filters)
            self.assertTrue(torch.autograd.gradcheck(KernelFn.apply, inputs))
Ejemplo n.º 3
0
    def test_compare_forward(self):
        """Check that the custom-backward forward pass matches the default forward pass.

        For each normalization, two kernels are built from the same seed. Their
        parameters must be identical, and so must their outputs on self.geometry.
        """
        for normalization in ["norm", "component"]:
            # Re-seed before each construction so both kernels are initialized identically.
            torch.manual_seed(0)
            K = Kernel(self.Rs_in, self.Rs_out, RadialModel=ConstantRadialModel, normalization=normalization)
            new_features = K(self.geometry)

            torch.manual_seed(0)
            K2 = Kernel(self.Rs_in, self.Rs_out, RadialModel=ConstantRadialModel, normalization=normalization)
            check_new_features = K2(self.geometry, custom_backward=True)

            # Compare the two distinct kernels; comparing K with itself would always pass.
            assert all(torch.all(a == b) for a, b in zip(K.parameters(), K2.parameters())), self.msg
            self.assertTrue(torch.allclose(new_features, check_new_features))
Ejemplo n.º 4
0
def test_flow():
    """
    This test checks that information is flowing as expected from target to source.
    edge_index[0] is source (convolution center)
    edge_index[1] is target (neighbors)
    """

    edge_index = torch.LongTensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
    ])
    # Node 0 carries -1 and nodes 1-4 carry +1; every edge vector is the same.
    features = torch.tensor([-1., 1., 1., 1., 1.])
    features = features.unsqueeze(-1)
    edge_r = torch.ones(edge_index.shape[-1], 3)

    Rs = [0]
    conv = Convolution(Kernel(Rs, Rs, ConstantRadialModel))
    conv.kernel.R.weight.data.fill_(1.)  # Fix weight to 1.

    # Node 0 is the source of all four edges, so only it should receive a value.
    output = conv(features, edge_index, edge_r)
    # The comparison must be asserted; a bare torch.allclose call verifies nothing.
    assert torch.allclose(output, torch.tensor([4., 0., 0., 0., 0.]).unsqueeze(-1))

    # Reversed edges: nodes 1-4 are now sources and each sees node 0's -1 feature.
    edge_index = torch.LongTensor([[1, 2, 3, 4], [0, 0, 0, 0]])
    output = conv(features, edge_index, edge_r)
    assert torch.allclose(output,
                          torch.tensor([0., -1., -1., -1., -1.]).unsqueeze(-1))
Ejemplo n.º 5
0
    def test1(self):
        """Check that Convolution followed by GatedBlock commutes with a random rotation."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
            Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

            f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5),
                           rescaled_act.sigmoid)
            c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

            abc = torch.randn(3)
            # Representation matrices: one irr_repr block per copy of each irrep.
            D_in = o3.direct_sum(
                *
                [o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
            D_out = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)
            ])

            x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geo = torch.randn(1, 5, 3)

            # Rotated features and rotated geometry.
            rx = torch.einsum("ij,zaj->zai", (D_in, x))
            rgeo = geo @ o3.rot(*abc).t()

            y = f(c(x, geo), dim=2)
            ry = torch.einsum("ij,zaj->zai", (D_out, y))

            # Gate along the same axis on both sides so the comparison is like-for-like.
            self.assertLess((f(c(rx, rgeo), dim=2) - ry).norm(), 1e-10 * ry.norm())
Ejemplo n.º 6
0
    def __init__(self, Rs_in, Rs_out):
        super().__init__()

        # Workaround: reuse Kernel to build the weight matrix. The filter rule
        # below allows only the l=0 filter, and only when input and output
        # orders match.
        def only_scalar_filter(l_in, l_out):
            if l_in != l_out:
                return []
            return [0]

        self.kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel, only_scalar_filter)
Ejemplo n.º 7
0
    def __init__(self, Rs_in, Rs_out, size, **kwargs):
        super().__init__()

        # Cosine-basis radial model with (size + 1) // 2 basis functions.
        n_basis = (size + 1) // 2
        radial_model = partial(
            CosineBasisModel,
            max_radius=1.0,
            number_of_basis=n_basis,
            h=50,
            L=3,
            act=rescaled_act.relu,
        )
        self.kernel = Kernel(Rs_in, Rs_out, radial_model, normalization='component')

        # Regular grid of `size` points per axis on [-1, 1]^3, with the three
        # coordinates stacked on the last dimension.
        axis = torch.linspace(-1, 1, size)
        self.r = torch.stack(torch.meshgrid(axis, axis, axis), dim=-1)
        self.kwargs = kwargs
Ejemplo n.º 8
0
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check rotation equivariance of Convolution on a random graph.

    Three variants are checked: a plain kernel, the same convolution called with
    ``groups``, and a GroupKernel-based convolution. For each, the output on
    rotated inputs (mapped back with ``@ D_out``) must match the unrotated output.
    """
    torch.set_default_dtype(torch.float64)

    mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
    groups = 4
    mp_group = Convolution(
        GroupKernel(Rs_in, Rs_out,
                    partial(Kernel, RadialModel=ConstantRadialModel), groups))

    features = rs.randn(n_target, Rs_in)
    features2 = rs.randn(n_target, Rs_in * groups)

    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)

    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge, )),
        torch.randint(n_target, size=(n_edge, )),
    ])
    size = (n_target, n_source)

    # Edge vectors (target position minus source position), one row per edge.
    # Vectorized indexing also yields a (0, 3) tensor when n_edge == 0, so no
    # special case is needed.
    edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
    out1_kernel_groups = mp_group(features2,
                                  edge_index,
                                  edge_r,
                                  size=size,
                                  groups=groups)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    D_in_groups = rs.rep(Rs_in * groups, *angles)
    D_out_groups = rs.rep(Rs_out * groups, *angles)
    R = o3.rot(*angles)

    # Rotate features and edge vectors, then map outputs back with `@ D_out`.
    out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
    out2_groups = mp(features2 @ D_in_groups.T,
                     edge_index,
                     edge_r @ R.T,
                     size=size,
                     groups=groups) @ D_out_groups
    out2_kernel_groups = mp_group(features2 @ D_in_groups.T,
                                  edge_index,
                                  edge_r @ R.T,
                                  size=size,
                                  groups=groups) @ D_out_groups

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
    assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
Ejemplo n.º 9
0
    def test2(self):
        """Check the empirical mean and variance of a 'norm'-normalized Kernel.

        For a single very wide input irrep and each (l_in, l_out) pair up to 3:
        - the kernel entries should average near zero;
        - their variance scaled by the input multiplicity should approximate
          1 / (2 * l_out + 1).
        """
        with torch_default_dtype(torch.float64):
            mul = 100000
            for l_in in range(4):
                Rs_in = [(mul, l_in)]
                for l_out in range(4):
                    Rs_out = [(1, l_out)]

                    kernel = Kernel(Rs_in,
                                    Rs_out,
                                    ConstantRadialModel,
                                    normalization='norm')
                    values = kernel(torch.randn(1, 3))

                    # Bound the magnitude: a large negative mean must also fail.
                    self.assertLess(abs(values.mean().item()), 1e-3)
                    self.assertAlmostEqual(values.var().item() * mul,
                                           1 / (2 * l_out + 1),
                                           places=1)
Ejemplo n.º 10
0
    def __init__(self, num_radial=30, max_radius=2):
        super().__init__()

        # Cosine-basis radial network with a Softplus(beta=5) activation;
        # the hidden width (h) and depth (L) are fixed here.
        radial_model = partial(
            CosineBasisModel,
            max_radius=max_radius,
            number_of_basis=num_radial,
            h=100,
            L=2,
            act=rescaled_act.Softplus(beta=5),
        )
        # Map a single l=0 input channel to a single l=1 output channel.
        kernel = Kernel([(1, 0)], [(1, 1)], radial_model)
        self.conv = Convolution(kernel)
Ejemplo n.º 11
0
def test_equivariance():
    """Check rotation equivariance of DepthwiseConvolution on a small random graph.

    Two variants are checked: one built from plain Kernels and one built from
    GroupKernels. For each, the output on rotated inputs (mapped back with
    ``@ D_out``) must match the unrotated output.
    """
    torch.set_default_dtype(torch.float64)

    n_edge = 3
    n_source = 4
    n_target = 2

    Rs_in = [(3, 0), (0, 1)]
    Rs_mid1 = [(5, 0), (1, 1)]
    Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
    Rs_out = [(5, 1), (3, 2)]

    # Bind `groups` before the factories that close over it.
    groups = 4

    def convolution(Rs_in, Rs_out):
        return Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))

    def convolution_groups(Rs_in, Rs_out):
        return Convolution(
            GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

    mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
    mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution_groups)

    features = rs.randn(n_target, Rs_in)

    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)

    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge,)),
        torch.randint(n_target, size=(n_edge,)),
    ])
    size = (n_target, n_source)

    # Edge vectors (target position minus source position), one row per edge.
    edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp_groups(features, edge_index, edge_r, size=size)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)

    # Rotate features and edge vectors, then map outputs back with `@ D_out`.
    out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
    out2_groups = mp_groups(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
Ejemplo n.º 12
0
    def get_kernel_conv_kernelconv(self, seed, normalization):
        """Build a Convolution(Kernel) and a KernelConv from the same random seed.

        The seed is reset immediately before each construction, so both modules
        start from the same random state.

        Returns:
            Tuple ``(Convolution, KernelConv)`` over ``self.Rs_in -> self.Rs_out``.
        """
        torch.manual_seed(seed)
        kernel = Kernel(self.Rs_in, self.Rs_out, ConstantRadialModel,
                        normalization=normalization)
        conv = Convolution(kernel)

        torch.manual_seed(seed)
        kernel_conv = KernelConv(self.Rs_in, self.Rs_out,
                                 RadialModel=ConstantRadialModel,
                                 normalization=normalization)

        return conv, kernel_conv
Ejemplo n.º 13
0
 def layer(Rs1, Rs2):
     """Return a Convolution mapping representation Rs1 to Rs2.

     Radial hyperparameters (max_radius, radial_basis, radial_neurons,
     radial_layers, radial_act, min_radius) and lmax are read from the
     enclosing scope.
     """
     radial_model = partial(
         GaussianRadialModel,
         max_radius=max_radius,
         number_of_basis=radial_basis,
         h=radial_neurons,
         L=radial_layers,
         act=radial_act,
         min_radius=min_radius,
     )
     selection_rule = partial(o3.selection_rule_in_out_sh, lmax=lmax)
     kernel = Kernel(Rs1, Rs2, radial_model, selection_rule,
                     allow_unused_inputs=True)
     return Convolution(kernel)
Ejemplo n.º 14
0
    def test2(self):
        """Test rotation equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)
            abc = torch.randn(3)

            def representation(Rs):
                # One o3.irr_repr block per copy of each irrep, combined with o3.direct_sum.
                blocks = []
                for mul, l in Rs:
                    blocks.extend(o3.irr_repr(l, *abc) for _ in range(mul))
                return o3.direct_sum(*blocks)

            D_in = representation(Rs_in)
            D_out = representation(Rs_out)

            # Rotating the kernel's output must equal evaluating it at the
            # rotated point and then applying the input representation.
            lhs = D_out @ kernel(r)  # [i, j]
            rhs = kernel(o3.rot(*abc) @ r) @ D_in  # [i, j]
            self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
Ejemplo n.º 15
0
    def test4(self):
        """Test parity equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
            Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]

            kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            def parity_matrix(Rs):
                # Each copy of an irrep contributes a (2l+1)-sized identity block scaled by its parity p.
                blocks = [
                    p * torch.eye(2 * l + 1)
                    for mul, l, p in Rs
                    for _ in range(mul)
                ]
                return o3.direct_sum(*blocks)

            D_in = parity_matrix(Rs_in)
            D_out = parity_matrix(Rs_out)

            # Applying the output parity must equal evaluating at -r and applying the input parity.
            lhs = D_out @ kernel(r)  # [i, j]
            rhs = kernel(-r) @ D_in  # [i, j]
            self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())