コード例 #1
0
ファイル: test_SO3Scalar.py プロジェクト: zizai/cormorant
    def test_SO3Scalar_check_cplx_fail(self, batch, maxl, channels):
        """SO3Scalar must reject parts whose trailing shape is malformed.

        Two inputs are tried, with trailing shapes ``(1, 1)`` and ``(1, 3)``
        after the channel axis (the test name suggests this exercises the
        complex-pair axis check). Each must make the constructor raise
        ``ValueError``.
        """
        num_parts = maxl + 1
        for last_dim in (1, 3):
            bad_parts = [torch.rand(batch + (channels, 1, last_dim))
                         for _ in range(num_parts)]
            with pytest.raises(ValueError):
                SO3Scalar(bad_parts)
コード例 #2
0
ファイル: test_so3_torch.py プロジェクト: zizai/cormorant
    def test_so3_scalar_so3_scalar_mul(self, maxl, num_middle):
        """The product of two SO3Scalars matches the part-wise NumPy product."""
        scalar_size = (2,) + (3,) * num_middle + (4, 2)

        def random_scalar():
            # One random part per l, all with the same shape.
            return SO3Scalar([torch.randn(scalar_size) for _ in range(maxl)])

        scalar1 = random_scalar()
        scalar1_numpy = [numpy_from_complex(part) for part in scalar1]
        scalar2 = random_scalar()
        scalar2_numpy = [numpy_from_complex(part) for part in scalar2]

        expected = [lhs * rhs for lhs, rhs in zip(scalar1_numpy, scalar2_numpy)]
        actual = [numpy_from_complex(part) for part in scalar1 * scalar2]
        for got, want in zip(actual, expected):
            assert np.sum(np.abs(got - want)) < 1E-6
コード例 #3
0
ファイル: test_SO3Scalar.py プロジェクト: zizai/cormorant
    def test_SO3Scalar_check_batch_fail(self, batch, channels):
        """SO3Scalar accepts parts only when they all share one batch shape.

        ``batch`` supplies one batch shape per part. If all the shapes are
        identical, construction must succeed. Otherwise it must raise
        ``ValueError``.
        """
        num_parts = len(batch)
        tau = torch.randint(1, channels + 1, [num_parts])

        parts = [torch.rand(part_batch + (t, 2))
                 for part_batch, t in zip(batch, tau)]

        batches_agree = len(set(batch)) == 1
        if not batches_agree:
            with pytest.raises(ValueError):
                SO3Scalar(parts)
        else:
            SO3Scalar(parts)
コード例 #4
0
ファイル: test_SO3Scalar.py プロジェクト: zizai/cormorant
    def test_SO3Scalar_init_arb_tau(self, batch, maxl, channels):
        """``SO3Scalar.rand`` with a random tau reports that same tau back.

        Each entry of tau is drawn from ``[1, channels]``.
        """
        num_parts = maxl + 1
        tau_list = torch.randint(1, channels + 1, [num_parts])
        assert SO3Scalar.rand(batch, tau_list).tau == tau_list
コード例 #5
0
ファイル: test_SO3Scalar.py プロジェクト: zizai/cormorant
    def test_SO3Scalar_init_channels(self, batch, maxl, channels):
        """A uniform tau (``channels`` for every l) round-trips through ``.tau``."""
        tau_list = [channels for _ in range(maxl + 1)]
        scalar = SO3Scalar.rand(batch, tau_list)
        assert scalar.tau == tau_list
コード例 #6
0
def test_mix_SO3Scalar(batch, maxl, channels1, channels2):
    """Smoke test: ``mix`` runs on a random SO3Weight and a random SO3Scalar.

    The weight maps ``channels1`` to ``channels2`` channels for every l. The
    result is not checked; the test only ensures the call does not raise.
    """
    num_parts = maxl + 1
    tau_in, tau_out = [channels1] * num_parts, [channels2] * num_parts

    test_scalar = SO3Scalar.rand(batch, tau_in)
    test_weight = SO3Weight.rand(tau_in, tau_out)

    # Shapes are printed to make a failure easier to diagnose.
    print(test_scalar.shapes, test_weight.shapes)
    mix(test_weight, test_scalar)
コード例 #7
0
ファイル: test_SO3Scalar.py プロジェクト: zizai/cormorant
    def test_SO3Scalar_mul_scalar(self, batch, maxl, channels):
        """Multiplying by a plain number (from either side) doubles every part."""
        parts = [torch.rand(batch + (channels, 2)) for _ in range(maxl + 1)]
        vec0 = SO3Scalar(parts)

        def is_doubled(result):
            return all(torch.allclose(2 * before, after)
                       for before, after in zip(vec0, result))

        # int on the left: Python falls back to the SO3Scalar's reflected multiply.
        assert is_doubled(2 * vec0)
        # float on the right: the SO3Scalar's own multiply.
        assert is_doubled(vec0 * 2.0)
コード例 #8
0
ファイル: test_SO3Scalar.py プロジェクト: zizai/cormorant
    def test_SO3Scalar_mul_list(self, batch, maxl, channels):
        """Multiplying by a list of floats scales each part by its own factor.

        Both ``list * SO3Scalar`` and ``SO3Scalar * list`` are checked.
        """
        vec0 = SO3Scalar([torch.rand(batch + (channels, 2))
                          for _ in range(maxl + 1)])

        factors = [torch.rand(1).item() for _ in vec0]

        left = factors * vec0
        assert all(torch.allclose(f * before, after)
                   for before, f, after in zip(vec0, factors, left))

        right = vec0 * factors
        assert all(torch.allclose(before * f, after)
                   for before, f, after in zip(vec0, factors, right))
コード例 #9
0
ファイル: position_levels.py プロジェクト: zizai/cormorant
    def forward(self, norms, edge_mask):
        """
        Evaluate the radial functions on a tensor of pairwise norms.

        Parameters
        ----------
        norms : :class:`torch.Tensor`
            Pairwise norms. Outputs are reshaped to ``norms.shape`` plus
            trailing axes.
        edge_mask : :class:`torch.Tensor`
            Edge mask. It is further restricted to entries with ``norms > 0``,
            and masked-out entries are filled with ``self.zero``.

        Returns
        -------
        :class:`SO3Scalar <cormorant.so3_lib.SO3Scalar>`
            With ``mix == 'cplx'`` or ``'real'``: one tensor per module in
            ``self.linear``. Otherwise: the unmixed products repeated
            ``self.max_sh + 1`` times.
        """
        # Shape to resize at end
        s = norms.shape

        # Mask (also dropping zero-length edges) and add a trailing axis
        edge_mask = (edge_mask * (norms > 0)).unsqueeze(-1)
        norms = norms.unsqueeze(-1)

        # Substitute 1 for masked-out norms before taking inverse powers.
        # torch.where backpropagates through BOTH branches, so evaluating
        # 0 ** -p at masked entries would inject inf * 0 = NaN into the
        # gradient of ``norms``. The forward values are unchanged because
        # those entries are replaced by ``self.zero`` below.
        safe_norms = torch.where(edge_mask, norms, torch.ones_like(norms))

        # Get inverse powers r^0, r^-1, ..., r^-rpow
        rad_powers = torch.stack([
            torch.where(edge_mask, safe_norms.pow(-power), self.zero)
            for power in range(self.rpow + 1)
        ],
                                 dim=-1)

        # Calculate trig functions
        rad_trig = torch.where(
            edge_mask, torch.sin((2 * pi * self.scales) * norms + self.phases),
            self.zero).unsqueeze(-1)

        # Take the product of the radial powers and the trig components and reshape
        rad_prod = (rad_powers * rad_trig).view(s + (
            1,
            2 * self.num_rad,
        ))

        # Apply linear mixing function, if desired
        if self.mix == 'cplx':
            radial_functions = [
                linear(rad_prod).view(s + (self.num_channels, 2))
                for linear in self.linear
            ]
        elif self.mix == 'real':
            radial_functions = [
                linear(rad_prod).view(s + (self.num_channels, ))
                for linear in self.linear
            ]
            # Hack because real-valued SO3Scalar class has not been implemented yet.
            # TODO: Implement real-valued SO3Scalar and fix this...
            # Pads each real tensor with a zero imaginary component on a new last axis.
            radial_functions = [
                torch.stack([rad, torch.zeros_like(rad)], dim=-1)
                for rad in radial_functions
            ]
        else:
            # No mixing: the same tensor object is shared by every part.
            radial_functions = [rad_prod.view(s + (self.num_rad, 2))
                                ] * (self.max_sh + 1)

        return SO3Scalar(radial_functions)
コード例 #10
0
ファイル: test_so3_torch.py プロジェクト: zizai/cormorant
    def test_so3_scalar_so3_vector_mul(self, maxl, num_middle):
        """SO3Scalar times SO3Vec, in either operand order, matches NumPy broadcasting."""
        lead_shape = (2,) + (3,) * num_middle + (4,)

        scalar = SO3Scalar([torch.randn(lead_shape + (2,)) for _ in range(maxl)])
        scalar_numpy = [numpy_from_complex(part) for part in scalar]
        vector = SO3Vec([torch.randn(lead_shape + (2 * ell + 1, 2))
                         for ell in range(maxl)])
        vector_numpy = [numpy_from_complex(part) for part in vector]

        # Append an axis so each scalar broadcasts over the vector's 2l+1 axis
        # (assumes numpy_from_complex folds the trailing real/imag pair away).
        expected = [np.expand_dims(s_part, -1) * v_part
                    for s_part, v_part in zip(scalar_numpy, vector_numpy)]

        def check(product):
            actual = [numpy_from_complex(part) for part in product]
            for got, want in zip(actual, expected):
                assert np.sum(np.abs(got - want)) < 1E-6

        check(vector * scalar)
        check(scalar * vector)
コード例 #11
0
    def forward(self, reps):
        r"""
        Performs the forward pass.

        Parameters
        ----------
        reps : :class:`SO3Vec <cormorant.so3_lib.SO3Vec>`
            Input SO3 Vector.

        Returns
        -------
        dot_products : :class:`SO3Scalar <cormorant.so3_lib.SO3Scalar>`
            SO3 scalars representing a Matrix of form :math:`(\psi_i \cdot \psi_j)_c`, where c is a channel index with :math:`|C| = \sum_l \tau_l`.
            If ``self.cat`` is set, the per-part results are concatenated along
            dim ``-2`` and that single tensor is repeated once per part of ``reps``.

        Raises
        ------
        ValueError
            If ``self.tau_in`` was set and differs from ``reps.tau``.
        """
        if self.tau_in is not None and self.tau_in != reps.tau:
            raise ValueError(
                'Initialized tau not consistent with tau from forward! {} {}'.
                format(self.tau_in, reps.tau))

        signs = self.signs
        conj = self.conj

        # Insert singleton axes at different positions so that the products
        # below broadcast every "i" entry against every "j" entry.
        # NOTE(review): assumes parts are laid out as
        # (..., i, channels, 2l+1, 2); confirm against SO3Vec.
        reps1 = [part.unsqueeze(-4) for part in reps]
        reps2 = [part.unsqueeze(-5) for part in reps]

        # Reverse the second-to-last axis and apply the per-part signs
        # (presumably the m -> -m conjugate-symmetry relation; see self.signs).
        reps2 = [part.flip(-2) * sign for part, sign in zip(reps2, signs)]

        # Complex product summed over the last two axes. The real part weights
        # the elementwise product by self.conj (presumably (1, -1), giving
        # ac - bd). The imaginary part pairs each component with the swapped
        # real/imag of the other operand, giving ad + bc.
        dot_product_r = [(part1 * part2 * conj).sum(dim=(-2, -1))
                         for part1, part2 in zip(reps1, reps2)]
        dot_product_i = [(part1 * part2.flip(-1)).sum(dim=(-2, -1))
                         for part1, part2 in zip(reps1, reps2)]

        # Re-pack real and imaginary parts onto a new trailing axis.
        dot_products = [
            torch.stack([prod_r, prod_i], dim=-1)
            for prod_r, prod_i in zip(dot_product_r, dot_product_i)
        ]

        if self.cat:
            dot_products = torch.cat(dot_products, dim=-2)
            dot_products = [dot_products] * len(reps)

        return SO3Scalar(dot_products)
コード例 #12
0
    def test_covariance(self, tau, num_channels, maxl, basis, edge_net_type,
                        sample_batch):
        """Edge-level outputs must be unchanged by a global rotation.

        The edge level is evaluated twice:
        - once on the original atom representations, norms and radial functions;
        - once on Wigner-rotated representations with norms and radial
          functions recomputed from the rotated positions.

        The outputs are compared directly, so they must agree element-wise
        to 1e-5.
        """
        data, __, __ = sample_batch
        device, dtype = data['positions'].device, data['positions'].dtype
        sph_harms = SphericalHarmonicsRel(maxl - 1,
                                          conj=True,
                                          device=device,
                                          dtype=dtype,
                                          cg_dict=None)
        batch_size, natoms = data['positions'].shape[:2]
        D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)
        # Setup Input
        atom_reps, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
            data, tau, maxl)
        atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)
        atom_reps_rot = atom_reps.apply_wigner(D)

        # Pairwise norms for the original and the rotated positions
        # (the spherical-harmonic outputs are not needed here).
        __, norms = sph_harms(atom_positions, atom_positions)
        __, norms_rot = sph_harms(atom_positions_rot, atom_positions_rot)

        rad_funcs = RadialFilters([maxl - 1], [basis, basis], [num_channels],
                                  1,
                                  device=device,
                                  dtype=dtype)
        rad_func_levels = rad_funcs(norms, edge_mask * (norms > 0))
        # Recompute the radial functions from the rotated norms, so that the
        # rotated pass is fed only quantities derived from the rotated geometry.
        rad_func_levels_rot = rad_funcs(norms_rot, edge_mask * (norms_rot > 0))
        tau_pos = rad_funcs.tau[0]

        # Build the initial edge network
        if edge_net_type is None:
            edge_reps = None
        elif edge_net_type == 'rand':
            reps = [
                torch.randn((batch_size, natoms, natoms, tau, 2))
                for i in range(maxl)
            ]
            edge_reps = SO3Scalar(reps)
        else:
            raise ValueError

        # Build Edge layer
        tlist = [tau] * maxl
        tau_atoms = tlist
        tau_edge = tlist
        if edge_net_type is None:
            tau_edge = []

        edge_lvl = CormorantEdgeLevel(tau_atoms,
                                      tau_edge,
                                      tau_pos,
                                      num_channels,
                                      maxl,
                                      cutoff_type='soft',
                                      device=device,
                                      dtype=dtype,
                                      hard_cut_rad=1.73,
                                      soft_cut_rad=1.73,
                                      soft_cut_width=0.2)

        output_edge_reps = edge_lvl(edge_reps, atom_reps, rad_func_levels[0],
                                    edge_mask, norms)
        output_edge_reps_rot = edge_lvl(edge_reps, atom_reps_rot,
                                        rad_func_levels_rot[0], edge_mask,
                                        norms_rot)

        for i in range(maxl):
            assert (torch.max(
                torch.abs(output_edge_reps[i] - output_edge_reps_rot[i])) <
                    1E-5)
コード例 #13
0
ファイル: test_SO3Scalar.py プロジェクト: zizai/cormorant
import torch
import pytest

from cormorant.so3_lib import SO3Tau, SO3Scalar, SO3Scalar

def rand_scalar(batch, tau):
    """Build an SO3Scalar from uniform random parts, one per entry of ``tau``.

    Part ``l`` has shape ``batch + (tau[l], 2)``.
    """
    return SO3Scalar([torch.rand(batch + (t, 2)) for t in tau])

class TestSO3Scalar():

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(3))
    @pytest.mark.parametrize('channels', range(1, 3))
    def test_SO3Scalar_init_channels(self, batch, maxl, channels):
        """A uniform tau of ``channels`` for every l is reproduced by ``.tau``."""
        expected_tau = [channels] * (maxl + 1)
        assert SO3Scalar.rand(batch, expected_tau).tau == expected_tau

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(4))
    @pytest.mark.parametrize('channels', range(1, 4))
    def test_SO3Scalar_init_arb_tau(self, batch, maxl, channels):
        """A random tau (each entry in ``[1, channels]``) is reproduced by ``.tau``."""
        random_tau = torch.randint(low=1, high=channels + 1, size=[maxl + 1])
        generated = SO3Scalar.rand(batch, random_tau)
        assert generated.tau == random_tau