示例#1
0
    def test_cg_dict_init(self, maxl):
        """Check the key count of a fresh CGDict and that its matrices are orthogonal.

        A CGDict built with ``maxl`` should hold one entry per ``(l1, l2)``
        pair with ``0 <= l1, l2 <= maxl``. Clebsch-Gordan coefficients form an
        orthogonal change of basis, so ``val @ val.t()`` should be the
        identity (assumes each value is a real 2-D matrix -- confirm against
        CGDict).
        """
        cg_dict = CGDict(maxl=maxl)

        assert len(cg_dict.keys()) == (maxl + 1)**2

        for _, val in cg_dict.items():
            # Taking the truth value of the element-wise ``val * val.t()``
            # raises for any tensor with more than one element, so compare
            # the matrix product against the identity instead.
            identity = torch.eye(val.shape[0], dtype=val.dtype,
                                 device=val.device)
            assert torch.allclose(val @ val.t(), identity, atol=1e-5)
示例#2
0
    def test_cg_dict_device_dtype(self, maxl, device, dtype):
        """Build a CGDict on the requested device and dtype.

        Asking for CUDA on a machine without CUDA is expected to raise an
        ``AssertionError``; every other combination must construct cleanly.
        """
        cuda_requested_but_missing = (device == torch.device('cuda')
                                      and not torch.cuda.is_available())

        if cuda_requested_but_missing:
            with pytest.raises(AssertionError):
                CGDict(maxl=maxl, device=device, dtype=dtype)
            return

        CGDict(maxl=maxl, device=device, dtype=dtype)
示例#3
0
    def test_cg_dict_to(self, maxl, dtype1, dtype2):
        """``CGDict.to(dtype=...)`` must convert both the dict and its stored entries."""
        cg_dict = CGDict(maxl=maxl, dtype=dtype1)

        def check_dtype(expected):
            # The dict-level dtype must agree with an actual stored entry
            # and with the dtype we asked for.
            assert cg_dict.dtype == cg_dict[(0, 0)].dtype
            assert cg_dict.dtype == expected

        check_dtype(dtype1)
        cg_dict.to(dtype=dtype2)
        check_dtype(dtype2)
示例#4
0
    def test_cg_dict_update_maxl(self, maxl1, maxl2):
        """``update_maxl`` keeps the larger maxl and exposes every (l1, l2) key up to it."""
        expected_maxl = max(maxl1, maxl2)

        cg_dict = CGDict(maxl=maxl1)
        cg_dict.update_maxl(maxl2)

        weights = range(expected_maxl + 1)
        expected_keys = {(a, b) for a in weights for b in weights}

        assert cg_dict.maxl == expected_maxl
        assert set(cg_dict.keys()) == expected_keys
示例#5
0
    def _init_cg_dict(self, cg_dict, maxl):
        """
        Initialize the Clebsch-Gordan dictionary.

        If :cg_dict: is set, check the following::
        - The dtype of cg_dict matches with self (else ``ValueError``).
        - The device of cg_dict matches with self (else ``ValueError``).
        - If :maxl: is ``None``, it is taken from :cg_dict.maxl:.
        - If :maxl: > :cg_dict.maxl:, the CGDict is extended in place with
          ``cg_dict.update_maxl(maxl)`` so it contains all necessary
          coefficients.

        If :cg_dict: is not set, but :maxl: is set, a new ``CGDict`` is built
        with this module's device and dtype.

        If neither is set, ``self.cg_dict`` and ``self._maxl`` are ``None``.

        A warning is emitted when :maxl: is inferred from :cg_dict: or when
        :cg_dict: has to be extended.
        """
        import warnings

        # If cg_dict is defined, check it has the right properties
        if cg_dict is not None:
            if cg_dict.dtype != self.dtype:
                raise ValueError(
                    'CGDict dtype ({}) not match CGModule() dtype ({})'.format(
                        cg_dict.dtype, self.dtype))

            if cg_dict.device != self.device:
                raise ValueError(
                    'CGDict device ({}) not match CGModule() device ({})'.
                    format(cg_dict.device, self.device))

            if maxl is None:
                # ``Warning(...)`` by itself only creates an object; use
                # ``warnings.warn`` so the message is actually emitted, and
                # really adopt the dictionary's maxl as the message says.
                warnings.warn(
                    'maxl is not defined, setting maxl based upon CGDict maxl ({})!'
                    .format(cg_dict.maxl))
                maxl = cg_dict.maxl

            elif maxl > cg_dict.maxl:
                warnings.warn(
                    'CGDict maxl ({}) is smaller than CGModule() maxl ({}). Updating!'
                    .format(cg_dict.maxl, maxl))
                cg_dict.update_maxl(maxl)

            self.cg_dict = cg_dict
            self._maxl = maxl

        # No cg_dict supplied, but maxl is known: build a dedicated CGDict.
        elif maxl is not None:
            self.cg_dict = CGDict(maxl=maxl,
                                  device=self.device,
                                  dtype=self.dtype)
            self._maxl = maxl

        else:
            self.cg_dict = None
            self._maxl = None
示例#6
0
    def test_cg_product_covariance(self, maxl_cg, maxl1, maxl2, channels,
                                   batch):
        """Rotating the CG product must equal the CG product of rotated inputs."""
        maxl = max(maxl_cg, maxl1, maxl2)

        cg_dict = CGDict(maxl=maxl, dtype=torch.double)

        def rand_rep(tau, shape):
            # One random double tensor of shape shape + (t, 2*ell + 1, 2)
            # for each weight ell with multiplicity t.
            return [torch.rand(shape + (t, 2 * ell + 1, 2)).double()
                    for ell, t in enumerate(tau)]

        angles = torch.rand(3)
        D, _ = gen_rot(angles, maxl)

        rep1 = rand_rep([channels] * (maxl1 + 1), (batch, ))
        rep2 = rand_rep([channels] * (maxl2 + 1), (batch, ))

        # Path A: product first, rotation afterwards.
        rotated_after = rot.rotate_rep(
            D, cg_product(cg_dict, rep1, rep2, maxl=maxl_cg))

        # Path B: rotate both inputs, then take the product.
        rotated_before = cg_product(cg_dict,
                                    rot.rotate_rep(D, rep1),
                                    rot.rotate_rep(D, rep2),
                                    maxl=maxl_cg)

        for part_after, part_before in zip(rotated_after, rotated_before):
            assert torch.allclose(part_after, part_before)
示例#7
0
    def test_cg_dict_uninit(self):
        """A CGDict built without maxl is falsy and rejects lookups with ValueError."""
        empty_dict = CGDict()
        assert not empty_dict

        with pytest.raises(ValueError):
            empty_dict[(0, 0)]
示例#8
0
    def test_spherical_harmonics_vs_scipy(self, maxl, batch, conj):
        """``spherical_harmonics`` (qm normalization) must match the SciPy reference."""
        cg_dict = CGDict(maxl=maxl, dtype=torch.double)
        positions = torch.rand(batch + (3, ), dtype=torch.double)

        ours = spherical_harmonics(cg_dict, positions, maxl,
                                   normalize=True, conj=conj, sh_norm='qm')
        reference = sph_harms_from_scipy(positions, maxl, conj=conj)

        for mine, theirs in zip(ours, reference):
            assert torch.allclose(mine, theirs)
示例#9
0
    def test_cg_mod_set_from_cg_dict(self, maxl, dtype):
        """Build a CGModule around an existing float CGDict with maxl=1.

        A dtype that differs from the dict's (half/double) must raise
        ``ValueError``. Otherwise the module must report:

        - the requested dtype (float when None);
        - the CPU device;
        - ``maxl`` (or the dict's maxl of 1 when ``maxl`` is None);
        - a shared dict grown to ``max(1, maxl)``.
        """
        cg_dict = CGDict(maxl=1, dtype=torch.float)

        if dtype in [torch.half, torch.double]:
            # If data type in CGModule does not match CGDict, throw an error
            with pytest.raises(ValueError):
                CGModule(maxl=maxl, dtype=dtype, cg_dict=cg_dict)
            return

        cg_mod = CGModule(maxl=maxl, dtype=dtype, cg_dict=cg_dict)

        # Compute the expectations separately: ``assert a == b if c else d``
        # parses as ``assert (a == b) if c else d`` and silently passes on a
        # truthy ``d``.
        expected_dtype = torch.float if dtype is None else dtype
        expected_maxl = maxl if maxl is not None else 1
        expected_dict_maxl = max(1, maxl) if maxl is not None else 1

        assert cg_mod.dtype == expected_dtype
        assert cg_mod.device == torch.device('cpu')
        assert cg_mod.maxl == expected_maxl
        assert cg_mod.cg_dict
        assert cg_mod.cg_dict.maxl == expected_dict_maxl
示例#10
0
    def test_spherical_rel_harmonics_vs_scipy(self, maxl, batch, natoms1,
                                              natoms2, conj):
        """Relative spherical harmonics and norms must match the SciPy reference.

        The first (l == 0) harmonic component is excluded from the comparison.
        """
        cg_dict = CGDict(maxl=maxl, dtype=torch.double)

        pos1 = torch.rand(batch + (natoms1, 3), dtype=torch.double)
        pos2 = torch.rand(batch + (natoms2, 3), dtype=torch.double)

        harmonics, norms = spherical_harmonics_rel(cg_dict, pos1, pos2, maxl,
                                                   conj=conj, sh_norm='qm')
        ref_harmonics, ref_norms = sph_harms_rel_from_scipy(pos1, pos2, maxl,
                                                            conj=conj)

        # Drop the l == 0 pair and compare every remaining weight.
        for ours, ref in list(zip(harmonics, ref_harmonics))[1:]:
            assert torch.allclose(ours, ref)

        assert torch.allclose(norms, ref_norms)
示例#11
0
    def test_cg_product_dict_maxl(self, maxl_dict, maxl_prod, maxl1, maxl2,
                                  chan, batch):
        """Check ``cg_product`` against the maxl its CGDict was built with.

        When the dictionary covers ``maxl_prod``, ``maxl1`` and ``maxl2``, the
        product must succeed and its output multiplicity must match
        ``cg_product_tau``. Otherwise a ``ValueError`` whose message starts
        with ``'CG Dictionary maxl'`` is expected.
        """
        cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

        tau1, tau2 = [chan] * (maxl1 + 1), [chan] * (maxl2 + 1)

        rep1 = SO3Vec.rand(batch, tau1, dtype=torch.double)
        rep2 = SO3Vec.rand(batch, tau2, dtype=torch.double)

        if all(maxl_dict >= maxl for maxl in [maxl_prod, maxl1, maxl2]):
            cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

            # These checks belong on the success path: inside the
            # ``pytest.raises`` block below they would follow the raising call
            # and never run. The prediction is truncated at ``maxl_prod`` like
            # the product itself.
            # NOTE(review): assumes cg_product_tau accepts a ``maxl`` keyword
            # -- confirm against cg_lib.
            tau_out = cg_prod.tau
            tau_pred = cg_product_tau(tau1, tau2, maxl=maxl_prod)

            # Test to make sure the output type matches the expected output type
            assert list(tau_out) == list(tau_pred)

        else:
            with pytest.raises(ValueError) as e_info:
                cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

            assert str(e_info.value).startswith('CG Dictionary maxl')
示例#12
0
    def test_spherical_harmonics(self, maxl, channels, conj):
        """Harmonics of rotated points must equal Wigner-D-rotated harmonics.

        The Wigner D matrix is applied from the left when ``conj`` is set and
        from the right otherwise. In the non-conjugate case the Cartesian
        rotation matrix is transposed before rotating the points.
        """
        D, R, _ = rot.gen_rot(maxl, dtype=torch.double)
        D = SO3WignerD(D)

        if not conj:
            R = R.t()

        pos = torch.randn((channels, 3), dtype=torch.double)
        posr = rot.rotate_cart_vec(R, pos)

        cg_dict = CGDict(maxl, dtype=torch.double)

        sph_harms = spherical_harmonics(cg_dict, pos, maxl, conj=conj)
        sph_harmsr = spherical_harmonics(cg_dict, posr, maxl, conj=conj)

        # Named ``side`` rather than ``dir`` to avoid shadowing the builtin.
        side = 'left' if conj else 'right'

        sph_harmsd = so3_torch.apply_wigner(D, sph_harms, dir=side)

        diff = [(p1 - p2).abs().max()
                for p1, p2 in zip(sph_harmsr, sph_harmsd)]
        # Report the per-weight errors only on failure instead of printing
        # them on every run.
        assert all(d < 1e-6 for d in diff), diff
示例#13
0
    def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
        """``CGProduct`` must be equivariant under a left-applied Wigner D rotation.

        The product of rotated inputs is compared with the rotated product of
        the original inputs; the largest absolute deviation per weight must
        stay below 1e-6.
        """
        maxl_all = max(maxl1, maxl2, maxl)
        D, _, _ = rot.gen_rot(maxl_all)

        cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
        cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

        first = SO3Vec.randn(SO3Tau([channels] * (maxl1 + 1)), batch,
                             dtype=torch.double)
        second = SO3Vec.randn(SO3Tau([channels] * (maxl2 + 1)), batch,
                              dtype=torch.double)

        first_rot = first.apply_wigner(D, dir='left')
        second_rot = second.apply_wigner(D, dir='left')

        plain_product = cg_prod(first, second)
        product_of_rotated = cg_prod(first_rot, second_rot)
        rotated_product = plain_product.apply_wigner(D, dir='left')

        max_errors = [(a - b).abs().max()
                      for a, b in zip(product_of_rotated, rotated_product)]
        assert all(err < 1e-6 for err in max_errors)
示例#14
0
class CGModule(nn.Module):
    """
    Clebsch-Gordan module. This functions identically to a normal PyTorch
    nn.Module, except that it adds the ability to specify a
    Clebsch-Gordan dictionary, and has additional tracking behavior set up
    to allow the CG Dictionary to be compatible with the DataParallel module.

    If `cg_dict` is specified upon instantiation, it is checked to have the
    same `dtype` and `device` as the module and is then set as the
    Clebsch-Gordan dictionary for the CG module. If `maxl` is not given, it
    is taken from `cg_dict.maxl`. If `maxl` is larger than `cg_dict.maxl`,
    the dictionary is extended in place with `cg_dict.update_maxl(maxl)`.

    If `cg_dict` is not specified, and `maxl` is specified, a new `CGDict`
    is created with the module's `maxl`, `device`, and `dtype`. To have many
    modules share a single `CGDict`, create it once and pass it to each
    module as `cg_dict`.

    If neither is specified, `cg_dict` and `maxl` are both `None`.

    The conversion methods (`to`, `cuda`, `cpu`, `half`, `float`, `double`)
    also convert the attached `cg_dict` and update the tracked `device` and
    `dtype`.

    Parameters
    ----------
    cg_dict : :class:`CGDict`, optional
        Specify an input CGDict to use for Clebsch-Gordan operations.
    maxl : :class:`int`, optional
        Maximum weight to initialize the Clebsch-Gordan dictionary.
    device : :class:`torch.torch.device`, optional
        Device to initialize the module and Clebsch-Gordan dictionary to.
        Defaults to ``torch.device('cpu')``.
    dtype : :class:`torch.torch.dtype`, optional
        Data type to initialize the module and Clebsch-Gordan dictionary to.
        Must be half, float or double. Defaults to ``torch.float``.

    Raises
    ------
    ValueError
        If `dtype` is not half/float/double, or if `cg_dict` has a different
        dtype or device than the module.
    """
    def __init__(self,
                 cg_dict=None,
                 maxl=None,
                 device=None,
                 dtype=None,
                 *args,
                 **kwargs):
        # Device and dtype must be set first: validating or building the CG
        # dictionary reads ``self.device`` and ``self.dtype``.
        self._init_device_dtype(device, dtype)
        self._init_cg_dict(cg_dict, maxl)

        super().__init__(*args, **kwargs)

    def _init_device_dtype(self, device, dtype):
        """
        Initialize the default device and data type.

        device : :class:`torch.torch.device`, optional
            Set device for CGDict and related. If unset defaults to torch.device('cpu').

        dtype : :class:`torch.torch.dtype`, optional
            Set data type for CGDict and related. If unset defaults to torch.float.
            Only half, float and double are accepted; anything else raises
            ``ValueError``.
        """
        self._device = torch.device('cpu') if device is None else device

        if dtype is None:
            self._dtype = torch.float
        elif dtype in (torch.half, torch.float, torch.double):
            self._dtype = dtype
        else:
            raise ValueError(
                'CG Module only takes internal data types of half/float/double. Got: {}'
                .format(dtype))

    def _init_cg_dict(self, cg_dict, maxl):
        """
        Initialize the Clebsch-Gordan dictionary.

        If :cg_dict: is set, check the following::
        - The dtype of cg_dict matches with self (else ``ValueError``).
        - The device of cg_dict matches with self (else ``ValueError``).
        - If :maxl: is ``None``, it is taken from :cg_dict.maxl:.
        - If :maxl: > :cg_dict.maxl:, the CGDict is extended in place with
          ``cg_dict.update_maxl(maxl)`` so it contains all necessary
          coefficients.

        If :cg_dict: is not set, but :maxl: is set, a new ``CGDict`` is built
        with this module's device and dtype.

        If neither is set, ``self.cg_dict`` and ``self._maxl`` are ``None``.
        """
        import warnings

        # If cg_dict is defined, check it has the right properties
        if cg_dict is not None:
            if cg_dict.dtype != self.dtype:
                raise ValueError(
                    'CGDict dtype ({}) not match CGModule() dtype ({})'.format(
                        cg_dict.dtype, self.dtype))

            if cg_dict.device != self.device:
                raise ValueError(
                    'CGDict device ({}) not match CGModule() device ({})'.
                    format(cg_dict.device, self.device))

            if maxl is None:
                # ``Warning(...)`` by itself only creates an object; use
                # ``warnings.warn`` so the message is actually emitted, and
                # really adopt the dictionary's maxl as the message says.
                warnings.warn(
                    'maxl is not defined, setting maxl based upon CGDict maxl ({})!'
                    .format(cg_dict.maxl))
                maxl = cg_dict.maxl

            elif maxl > cg_dict.maxl:
                warnings.warn(
                    'CGDict maxl ({}) is smaller than CGModule() maxl ({}). Updating!'
                    .format(cg_dict.maxl, maxl))
                cg_dict.update_maxl(maxl)

            self.cg_dict = cg_dict
            self._maxl = maxl

        # No cg_dict supplied, but maxl is known: build a dedicated CGDict.
        elif maxl is not None:
            self.cg_dict = CGDict(maxl=maxl,
                                  device=self.device,
                                  dtype=self.dtype)
            self._maxl = maxl

        else:
            self.cg_dict = None
            self._maxl = None

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    @property
    def maxl(self):
        return self._maxl

    def to(self, *args, **kwargs):
        """Move/cast the module like ``nn.Module.to`` and keep ``cg_dict`` in sync."""
        super().to(*args, **kwargs)

        # ``_parse_to`` returns (device, dtype, non_blocking) on old PyTorch
        # releases and additionally ``memory_format`` on newer ones, so only
        # the first two entries are taken.
        device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]

        if self.cg_dict is not None:
            self.cg_dict.to(device=device, dtype=dtype)

        if device is not None:
            self._device = device

        if dtype is not None:
            self._dtype = dtype

        return self

    def cuda(self, device=None):
        """Move the module and ``cg_dict`` to a CUDA device.

        ``device`` may be ``None`` (default CUDA device), a device index, or
        anything ``torch.device`` accepts that names a CUDA device.

        Raises
        ------
        ValueError
            If the index is not a visible CUDA device, or the device is not
            a CUDA device.
        """
        if device is None:
            device = torch.device('cuda')
        elif isinstance(device, int):
            if device not in range(torch.cuda.device_count()):
                raise ValueError('Incorrect choice of device!')
            device = torch.device('cuda:{}'.format(device))
        else:
            device = torch.device(device)
            if device.type != 'cuda':
                raise ValueError('Incorrect choice of device!')

        super().cuda(device=device)

        if self.cg_dict is not None:
            self.cg_dict.to(device=device)

        self._device = device

        return self

    def cpu(self):
        """Move the module and ``cg_dict`` to the CPU."""
        super().cpu()

        if self.cg_dict is not None:
            self.cg_dict.to(device=torch.device('cpu'))

        self._device = torch.device('cpu')

        return self

    def half(self):
        """Cast the module and ``cg_dict`` to ``torch.half``."""
        super().half()

        if self.cg_dict is not None:
            self.cg_dict.to(dtype=torch.half)

        self._dtype = torch.half

        return self

    def float(self):
        """Cast the module and ``cg_dict`` to ``torch.float``."""
        super().float()

        if self.cg_dict is not None:
            self.cg_dict.to(dtype=torch.float)

        self._dtype = torch.float

        return self

    def double(self):
        """Cast the module and ``cg_dict`` to ``torch.double``."""
        super().double()

        if self.cg_dict is not None:
            self.cg_dict.to(dtype=torch.double)

        self._dtype = torch.double

        return self
示例#15
0
    def test_no_maxl_w_cg_dict(self, maxl):
        """A CGProduct given only a cg_dict must still expose a dict and a maxl."""
        product = CGProduct(cg_dict=CGDict(maxl=maxl))

        assert product.cg_dict is not None
        assert product.maxl is not None
示例#16
0
    def test_cg_dict_init(self, maxl):
        """A fresh CGDict holds exactly the keys (l1, l2) with 0 <= l1, l2 <= maxl."""
        cg_dict = CGDict(maxl=maxl)

        weights = range(maxl + 1)
        expected_keys = {(first, second) for first in weights for second in weights}

        assert set(cg_dict.keys()) == expected_keys
示例#17
0
    def test_cg_aggregate(self, maxl_dict, maxl_prod, maxl1, maxl2, chan,
                          batch, atom1, atom2):
        """Aggregated CG products must equal an explicit per-atom summation.

        ``rep1`` has leading shape ``(batch, atom1, atom2)`` and ``rep2`` has
        ``(batch, atom2)``. The ``aggregate=True`` product is compared with
        summing, over ``atom2``, the non-aggregated product of matching
        slices. Both argument orders are checked. Parameter combinations the
        dictionary cannot cover return early without checking anything.
        """
        if any(maxl_dict < maxl for maxl in [maxl_prod, maxl1, maxl2]):
            return

        cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

        def rand_rep(tau, shape):
            # One random double tensor of shape shape + (t, 2*ell + 1, 2)
            # per weight ell with multiplicity t.
            return [torch.rand(shape + (t, 2 * ell + 1, 2)).double()
                    for ell, t in enumerate(tau)]

        rep1 = rand_rep([chan] * (maxl1 + 1), (batch, atom1, atom2))
        rep2 = rand_rep([chan] * (maxl2 + 1), (batch, atom2))

        def check_aggregate(swap):
            # Compare one operand order (swapped or not) against the explicit sum.
            lhs, rhs = (rep2, rep1) if swap else (rep1, rep2)
            aggregated = cg_product(cg_dict, lhs, rhs, maxl=maxl_prod,
                                    aggregate=True)

            explicit = [torch.zeros_like(part) for part in aggregated]
            for b in range(batch):
                for a1 in range(atom1):
                    for a2 in range(atom2):
                        slice1 = [part[b, a1, a2] for part in rep1]
                        slice2 = [part[b, a2] for part in rep2]
                        pair = (slice2, slice1) if swap else (slice1, slice2)
                        single = cg_product(cg_dict, *pair, maxl=maxl_prod,
                                            aggregate=False)
                        # Recover each part's weight from its (2l + 1) axis.
                        weights = [(part.shape[-2] - 1) // 2 for part in single]
                        for ell in weights:
                            explicit[ell][b, a1] += single[ell]

            for got, want in zip(aggregated, explicit):
                assert torch.allclose(got, want)

        check_aggregate(swap=False)
        check_aggregate(swap=True)
示例#18
0
文件: agent.py 项目: gncs/molgym
    def __init__(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        min_max_distance: Tuple[float, float],
        network_width: int,
        maxl: int,
        num_cg_levels: int,
        num_channels_hidden: int,
        num_channels_per_element: int,
        num_gaussians: int,
        bag_scale: int,
        beta: Optional[float] = None,
        device=None,
    ):
        """Build the Cormorant-based networks and MLP heads used by this agent.

        Args:
            observation_space: Passed to the base class. Its ``zs`` attribute
                gives the element list (presumably atomic numbers -- confirm);
                ``len(zs)`` is used as the number of species and ``max(zs)``
                as Cormorant's ``charge_scale``.
            action_space: Passed to the base class.
            min_max_distance: ``(min_distance, max_distance)``; must satisfy
                ``min_distance < max_distance`` (asserted). Sets the center and
                half-width of the distance range. ``max_distance`` capped at
                2.1 is also used as both Cormorant cutoff radii.
            network_width: Hidden width of every MLP head.
            maxl: Used both as the CG cutoff (``maxl``) and the number of
                spherical-harmonic powers (``max_sh``).
            num_cg_levels: Number of Cormorant CG levels.
            num_channels_hidden: Channels per hidden CG level.
            num_channels_per_element: Output channels allotted to each element;
                the Cormorant output width is ``len(zs)`` times this.
            num_gaussians: Number of mixture components in the distance head.
            bag_scale: Forwarded to ``Cormorant``.
            beta: Stored on ``self.beta``; not otherwise used in this method.
            device: Device for the tensors created here; also passed to
                ``self.to`` at the end.
        """
        super().__init__(observation_space, action_space)
        self.device = device
        # All tensors and CG modules built here use single precision.
        self.dtype = torch.float

        self.zs = self.observation_space.zs
        self.zs_tensor = torch.tensor(self.zs, dtype=self.dtype, device=self.device)

        self.min_distance, self.max_distance = min_max_distance
        assert self.min_distance < self.max_distance
        self.beta = beta

        self.max_sh = maxl
        self.num_cg_levels = num_cg_levels
        self.num_channels_hidden = num_channels_hidden
        self.num_channels_per_element = num_channels_per_element
        self.num_gaussians = num_gaussians

        self.num_channels_out = len(self.zs) * self.num_channels_per_element
        # Row vector [[0, 1, ..., num_channels_per_element - 1]] of shape
        # (1, num_channels_per_element).
        self.channel_offsets = torch.arange(start=0,
                                            end=self.num_channels_per_element,
                                            dtype=torch.long,
                                            device=self.device).unsqueeze(0)

        # A single CGDict shared by Cormorant, the mixer and the spherical
        # harmonics module so the coefficients are only built once.
        self.cg_dict = CGDict(maxl=self.max_sh, device=self.device, dtype=self.dtype)
        self.cg_model = Cormorant(
            maxl=self.max_sh,  # Cutoff in CG operations (default: [3])
            max_sh=self.max_sh,  # Number of spherical harmonic powers to use (default: [3])
            num_cg_levels=self.num_cg_levels,  # Number of CG levels (default: 4)
            # One hidden width per CG level, followed by the per-element output width.
            num_channels=[self.num_channels_hidden] * self.num_cg_levels + [self.num_channels_out],
            num_species=len(self.zs),
            cutoff_type=['soft'],  # Types of cutoffs to include
            hard_cut_rad=min(self.max_distance, 2.1),  # Radius of hard cutoff (in AA)
            soft_cut_rad=min(self.max_distance, 2.1),  # Radius of soft cutoff (in AA)
            soft_cut_width=0.2,  # Width of SOFT cutoff in Angstroms (default: 0.2)
            weight_init='rand',  # Weight initialization function to use (default: rand)
            level_gain=[10.0],  # Gain at each level (default: [10.])
            charge_power=2,  # Maximum power to take in one-hot (default: 2)
            basis_set=[3, 3],  # Use gaussian mask instead of sigmoid mask.
            charge_scale=max(self.zs),
            bag_scale=bag_scale,
            device=self.device,
            dtype=self.dtype,
            cg_dict=self.cg_dict,
        )

        # Mixes per-element features (tau_in, all weights up to max_sh) with a
        # scalar-only "other" representation (tau_other).
        self.cg_mix = CormorantMixer(
            tau_in=SO3Tau([self.num_channels_per_element] * (self.max_sh + 1)),
            tau_other=SO3Tau([self.num_channels_per_element]),
            maxl=self.max_sh,
            num_channels=self.num_channels_per_element,
            level_gain=10.0,
            weight_init='rand',
            device=self.device,
            dtype=self.dtype,
            cg_dict=self.cg_dict,
        )

        self.sph_harms = SphericalHarmonics(maxl=self.max_sh,
                                            conj=False,
                                            sh_norm='qm',
                                            device=self.device,
                                            dtype=self.dtype,
                                            cg_dict=self.cg_dict)

        # Converts SO(3) features to scalars; note it does not receive the shared cg_dict.
        self.atomic_scalars = AtomicScalars(maxl=self.max_sh, full_scalars=True, device=self.device, dtype=self.dtype)

        # Scalar feature sizes for the full Cormorant output and for one element's channels.
        self.num_latent = self.atomic_scalars.get_output_dim(self.num_channels_out)
        self.num_latent_element = self.atomic_scalars.get_output_dim(self.num_channels_per_element)

        # Focus
        self.phi_focus = MLP(
            input_dim=self.num_latent,
            output_dims=(network_width, 1),
        )

        # Element
        self.phi_element = MLP(
            input_dim=self.num_latent,
            output_dims=(network_width, len(self.zs)),
        )

        # Distance: Gaussian Mixture Model
        # Output of size 2 * num_gaussians (presumably two values per component -- confirm in forward).
        self.phi_d = MLP(
            input_dim=self.num_latent_element,
            output_dims=(network_width, 2 * self.num_gaussians),
        )
        self.pad_zeros = torch.nn.ConstantPad1d(padding=(0, 1), value=0.0)  # Pad with one 0.0 to the right

        # NOTE(review): these are plain tensor attributes rather than
        # registered buffers, so a later ``.to()`` would not move them; they
        # are created directly on ``self.device`` instead.
        self.distance_half_width = torch.tensor((self.max_distance - self.min_distance) / 2,
                                                dtype=self.dtype,
                                                device=self.device)
        self.distance_center = torch.tensor((self.min_distance + self.max_distance) / 2,
                                            dtype=self.dtype,
                                            device=self.device)

        # Learnable log standard deviations, one per Gaussian, initialized to log(0.1).
        self.distance_log_stds = torch.nn.Parameter(torch.log(
            torch.tensor([0.1] * self.num_gaussians, dtype=self.dtype, device=self.device)),
                                                    requires_grad=True)  # (gaussians, )

        # Value function
        self.phi_trans = MLP(
            input_dim=self.num_latent,
            output_dims=(network_width, network_width),
        )
        self.phi_v = MLP(
            input_dim=network_width,
            output_dims=(network_width, 1),
        )

        self.to(self.device)