Esempio n. 1
0
    def test_SO3Vec_cat(self, batch1, batch2, batch3, channels1, channels2, channels3, maxl1, maxl2, maxl3):
        """Check ``so3_torch.cat`` on two and three SO3Vecs.

        When the batch shapes agree, the concatenated vector's tau must equal
        ``SO3Tau.cat`` of the input taus. When they disagree, the
        concatenation must raise ``RuntimeError``.
        """
        tau1 = [channels1] * (maxl1 + 1)
        tau2 = [channels2] * (maxl2 + 1)
        # The third vector uses its own maximum weight (maxl3). Previously it
        # reused maxl2, so the maxl3 parameter was never exercised.
        tau3 = [channels3] * (maxl3 + 1)

        tau12 = SO3Tau.cat([tau1, tau2])
        tau123 = SO3Tau.cat([tau1, tau2, tau3])

        vec1 = SO3Vec.randn(tau1, batch1)
        vec2 = SO3Vec.randn(tau2, batch2)
        vec3 = SO3Vec.randn(tau3, batch3)

        if batch1 == batch2:
            vec12 = so3_torch.cat([vec1, vec2])

            assert vec12.tau == tau12
        else:
            with pytest.raises(RuntimeError):
                so3_torch.cat([vec1, vec2])

        if batch1 == batch2 == batch3:
            vec123 = so3_torch.cat([vec1, vec2, vec3])

            assert vec123.tau == tau123
        else:
            with pytest.raises(RuntimeError):
                so3_torch.cat([vec1, vec2, vec3])
Esempio n. 2
0
    def test_channels(self):
        """``SO3Tau.channels`` is None for a non-uniform tau, else the shared count."""
        tau = SO3Tau([3, 2, 1])

        # Compare against the None singleton by identity (PEP 8).
        assert tau.channels is None

        tau = SO3Tau([3] * 2)

        assert tau.channels == 3
Esempio n. 3
0
    def __init__(self,
                 max_sh,
                 basis_set,
                 num_channels,
                 mix=False,
                 device=torch.device('cpu'),
                 dtype=torch.float):
        """Set up a learnable trigonometric/polynomial radial basis.

        Parameters
        ----------
        max_sh : int
            Maximum spherical-harmonic weight. One output block (and, when
            mixing, one Linear layer) is created per ``l`` in ``0..max_sh``.
        basis_set : tuple of (int, int)
            ``(trig_basis, rpow)``. Both must be non-negative. The number of
            radial functions is ``(trig_basis + 1) * (rpow + 1)``.
        num_channels : int
            Output channels per ``l`` when ``mix`` is enabled.
        mix : bool or str
            ``True`` or ``'cplx'``: per-``l`` ``Linear(2*num_rad, 2*num_channels)``.
            ``'real'``: per-``l`` ``Linear(2*num_rad, num_channels)``.
            ``False`` or ``'none'``: no mixing; the output tau has ``num_rad``
            channels per ``l``.
        device, dtype
            Placement and precision of the created tensors and layers.

        Raises
        ------
        AssertionError
            If ``trig_basis`` or ``rpow`` is negative.
        ValueError
            If ``mix`` is not one of the accepted values.
        """
        super(RadPolyTrig, self).__init__()

        trig_basis, rpow = basis_set
        self.rpow = rpow
        self.max_sh = max_sh

        # NOTE(review): `assert` is stripped under `python -O`; a ValueError would be more robust.
        assert (trig_basis >= 0 and rpow >= 0)

        self.num_rad = (trig_basis + 1) * (rpow + 1)
        self.num_channels = num_channels

        # This instantiates a set of functions sin(2*pi*n*x/a), cos(2*pi*n*x/a) with a=1.
        # Concretely: frequencies 0..trig_basis listed twice (one copy per phase
        # half below), reshaped to (1, 1, 1, 2 * (trig_basis + 1)).
        self.scales = torch.cat(
            [torch.arange(trig_basis + 1),
             torch.arange(trig_basis + 1)]).view(1, 1, 1, -1).to(device=device,
                                                                 dtype=dtype)
        # Phase 0 for the first half and pi/2 for the second half, so the
        # second half evaluates as sin(x + pi/2) = cos(x).
        self.phases = torch.cat([
            torch.zeros(trig_basis + 1), pi / 2 * torch.ones(trig_basis + 1)
        ]).view(1, 1, 1, -1).to(device=device, dtype=dtype)

        # This avoids the sin(0*r + 0) = 0 part from wasting computations.
        self.phases[0, 0, 0, 0] = pi / 2

        # Now, make the above learnable
        self.scales = nn.Parameter(self.scales)
        self.phases = nn.Parameter(self.phases)

        # If desired, mix the radial components to a desired shape
        self.mix = mix
        if (mix == 'cplx') or (mix is True):
            self.linear = nn.ModuleList([
                nn.Linear(2 * self.num_rad,
                          2 * self.num_channels).to(device=device, dtype=dtype)
                for _ in range(max_sh + 1)
            ])
            self.tau = SO3Tau((num_channels, ) * (max_sh + 1))
        elif mix == 'real':
            self.linear = nn.ModuleList([
                nn.Linear(2 * self.num_rad,
                          self.num_channels).to(device=device, dtype=dtype)
                for _ in range(max_sh + 1)
            ])
            self.tau = SO3Tau((num_channels, ) * (max_sh + 1))
        elif (mix == 'none') or (mix is False):
            self.linear = None
            self.tau = SO3Tau((self.num_rad, ) * (max_sh + 1))
        else:
            raise ValueError(
                'Can only specify mix = real, cplx, or none! {}'.format(mix))

        # NOTE(review): plain tensor attribute, not a registered buffer, so
        # Module.to()/cuda() will not move it with the rest of the module.
        self.zero = torch.tensor(0, device=device, dtype=dtype)
Esempio n. 4
0
    def test_init(self, maxl, num_channels):
        """SO3Tau can be built from a list of ints and copy-built from an SO3Tau."""
        expected = [num_channels] * (maxl + 1)

        # Construction from a plain list of multiplicities.
        built = SO3Tau([num_channels] * (maxl + 1))
        assert type(built) is SO3Tau
        assert list(built) == expected

        # Construction from an existing SO3Tau.
        built = SO3Tau(built)
        assert type(built) is SO3Tau
        assert list(built) == expected
Esempio n. 5
0
    def test_SO3Weight_init(self, maxl, channels1, channels2):
        """An SO3Weight built from a random weight list reports the input/output taus used."""
        num_weights = maxl + 1
        tau_in = SO3Tau([channels1] * num_weights)
        tau_out = SO3Tau([channels2] * num_weights)

        weight = SO3Weight(rand_weight_list(tau_in, tau_out))

        assert isinstance(weight, SO3Weight)
        assert weight.tau_in == SO3Tau(tau_in)
        assert weight.tau_out == SO3Tau(tau_out)
Esempio n. 6
0
    def __init__(self, taus_in, maxl=None):
        """Prepare concatenation of several SO(3) representation types.

        Parameters
        ----------
        taus_in : iterable
            Input types. Falsy entries (e.g. empty taus) are dropped; the rest
            are converted to SO3Tau.
        maxl : int, optional
            Highest weight kept in the output. Defaults to the largest
            ``maxl`` among the kept input types.
        """
        super().__init__()

        kept_taus = [SO3Tau(candidate) for candidate in taus_in if candidate]
        self.taus_in = kept_taus

        self.maxl = max(t.maxl for t in kept_taus) if maxl is None else maxl

        # `&` adds multiplicities weight by weight; slice to weights 0..maxl.
        merged = reduce(lambda acc, nxt: acc & nxt, kept_taus)
        self.tau_out = merged[:self.maxl + 1]
Esempio n. 7
0
    def __init__(self, tau_in=None, cat=True, device=None, dtype=None):
        """Configure the module and, if ``tau_in`` is known, precompute constants.

        Parameters
        ----------
        tau_in : sequence of int, optional
            Input multiplicities. When None, ``tau`` and ``signs`` are None.
        cat : bool
            If True, every output weight carries ``sum(tau_in)`` channels;
            otherwise the output tau copies ``tau_in``.
        device, dtype
            Forwarded to the base module; used for the precomputed tensors.
        """
        super().__init__(device=device, dtype=dtype)
        self.tau_in = tau_in
        self.cat = cat

        if self.tau_in is None:
            self.tau = None
            self.signs = None
            return

        num_parts = len(tau_in)
        if cat:
            multiplicities = [sum(tau_in)] * num_parts
        else:
            multiplicities = list(tau_in)
        self.tau = SO3Tau(multiplicities)

        # For each ell, a column of (-1)**m for m = -ell..ell, shape (2*ell + 1, 1).
        # NOTE(review): ell runs up to len(tau_in), one past tau_in's last weight.
        sign_columns = []
        for ell in range(num_parts + 1):
            exponents = torch.arange(-ell, ell + 1).float()
            column = torch.tensor(-1.).pow(exponents).to(device=self.device,
                                                         dtype=self.dtype)
            sign_columns.append(column.unsqueeze(-1))
        self.signs = sign_columns

        self.conj = torch.tensor([1., -1.]).to(device=self.device,
                                               dtype=self.dtype)
Esempio n. 8
0
    def forward(self, rep):
        """
        Apply the learned linear mixing to a representation.

        Parameters
        ----------
        rep : :obj:`list` of :obj:`torch.Tensor`
            Representation to mix; its tau must equal ``self.tau_in``.

        Returns
        -------
        rep : :obj:`list` of :obj:`torch.Tensor`
            Mixed representation, as returned by ``so3_torch.mix``.

        Raises
        ------
        ValueError
            If the tau inferred from ``rep`` differs from ``self.tau_in``.
        """
        observed_tau = SO3Tau.from_rep(rep)
        if observed_tau != self.tau_in:
            message = ('Tau of input rep does not match initialized tau!'
                       ' rep: {} tau: {}'.format(observed_tau, self.tau_in))
            raise ValueError(message)

        return so3_torch.mix(self.weights, rep)
Esempio n. 9
0
    def test_from_rep(self, batch, tau0):
        """``SO3Tau.from_rep`` recovers the multiplicities a random rep was built with."""
        def make_random_rep(multiplicities, batch_shape):
            # One tensor per weight ell, shaped batch + (channels, 2*ell + 1, 2).
            return [
                torch.rand(batch_shape + (count, 2 * ell + 1, 2)).double()
                for ell, count in enumerate(multiplicities)
            ]

        rep = make_random_rep(tau0, batch)
        recovered = SO3Tau.from_rep(rep)

        assert type(recovered) is SO3Tau
        assert list(recovered) == list(tau0)
Esempio n. 10
0
    def test_cat(self):
        """``SO3Tau.cat`` and ``&`` add multiplicities weight by weight.

        A shorter tau contributes nothing at the weights it lacks, e.g.
        ``[1, 2, 3]`` combined with ``[1, 1]`` gives ``[2, 3, 3]``.
        """
        tau1 = SO3Tau([1, 2, 3])
        tau2 = SO3Tau([1, 1])
        tau3 = SO3Tau([0, 0, 2])

        tau = SO3Tau.cat([tau1, tau2])
        assert list(tau) == [2, 3, 3]

        assert type(tau) == SO3Tau

        tau = (tau1 & tau2)
        assert list(tau) == [2, 3, 3]

        # From here on tau1 holds [2, 3, 3]; both sides of the final check use it.
        tau1 &= tau2
        assert list(tau1) == [2, 3, 3]

        tau123 = (tau1 & tau2) & tau3

        assert SO3Tau.cat([tau1, tau2, tau3]) == tau123
Esempio n. 11
0
    def __init__(self,
                 tau_in,
                 tau_out,
                 real=False,
                 weight_init='randn',
                 gain=1,
                 device=None,
                 dtype=None):
        """
        Linearly mix the channels of an SO(3) representation.

        Parameters
        ----------
        tau_in : :class:`SO3Tau` or :class:`list` of :class:`int`
            Input multiplicities.
        tau_out : :class:`SO3Tau`, :class:`list` of :class:`int`, or :class:`int`
            Output multiplicities. A bare ``int`` is used as the channel count
            at every weight of ``tau_in``.
        real : :class:`bool`
            Stored as ``self.real``.
        weight_init : :class:`str`
            ``'randn'`` or ``'rand'``. ``'rand'`` weights are transformed as
            ``2 * w - 1``.
        gain : :class:`float`
            Each weight block is scaled by ``gain / max(block shape)``.
        device, dtype
            Forwarded to the base module and the weight constructors.

        Raises
        ------
        NotImplementedError
            If ``weight_init`` is neither ``'randn'`` nor ``'rand'``.
        """
        super().__init__(device=device, dtype=dtype)
        tau_in = SO3Tau(tau_in)

        # Allow one to set the output tau to a pre-specified number of output channels.
        if type(tau_out) is int:
            tau_out = [tau_out] * len(tau_in)

        self.tau_in = tau_in
        self.tau_out = SO3Tau(tau_out)
        self.real = real

        # Compare strings by value. `is` tests object identity and only works by
        # accident for interned literals (SyntaxWarning on Python 3.8+). A string
        # built at runtime (config file, CLI) would wrongly hit NotImplementedError.
        if weight_init == 'randn':
            weights = SO3Weight.randn(self.tau_in,
                                      self.tau_out,
                                      device=device,
                                      dtype=dtype)
        elif weight_init == 'rand':
            weights = SO3Weight.rand(self.tau_in,
                                     self.tau_out,
                                     device=device,
                                     dtype=dtype)
            # Presumably maps uniform [0, 1) samples onto [-1, 1) — confirm SO3Weight.rand.
            weights = 2 * weights - 1
        else:
            raise NotImplementedError(
                'weight_init can only be randn or rand for now')

        # One scale factor per weight block: gain divided by that block's largest dimension.
        gains = [gain / max(shape) for shape in weights.shapes]
        weights = gains * weights

        self.weights = weights.as_parameter()
Esempio n. 12
0
def cg_product_tau(tau1, tau2, maxl=inf):
    """
    Calculate the output multiplicity of the CG product of two SO3 vectors,
    given the multiplicities of the two input SO3 vectors.

    The CG product pairs every channel of weight ``l1`` in the first vector
    with every channel of weight ``l2`` in the second. Each such pair yields
    one fragment at every weight ``l`` with
    ``|l1 - l2| <= l <= min(l1 + l2, maxl)``. Weight ``l`` therefore gains
    ``tau1[l1] * tau2[l2]`` channels from each contributing ``(l1, l2)``.

    Parameters
    ----------
    tau1 : :class:`list` of :class:`int`, :class:`SO3Tau`.
        Multiplicity of first representation.

    tau2 : :class:`list` of :class:`int`, :class:`SO3Tau`.
        Multiplicity of second representation.

    maxl : :class:`int`
        Largest weight to include in CG Product.

    Return
    ------

    tau : :class:`SO3Tau`
        Multiplicity of output representation.

    """
    tau1 = SO3Tau(tau1)
    tau2 = SO3Tau(tau2)

    L1, L2 = tau1.maxl, tau2.maxl
    L = min(L1 + L2, maxl)

    tau = [0] * (L + 1)

    for l1 in range(L1 + 1):
        for l2 in range(L2 + 1):
            lmin, lmax = abs(l1 - l2), min(l1 + l2, maxl)
            # Every (channel of l1) x (channel of l2) pair contributes one
            # fragment per allowed l. The previous code added only tau1[l1],
            # silently ignoring tau2's channel counts.
            pair_channels = tau1[l1] * tau2[l2]
            for l in range(lmin, lmax + 1):
                tau[l] += pair_channels

    return SO3Tau(tau)
Esempio n. 13
0
    def test_add(self):
        """``+`` concatenates taus and always yields an SO3Tau, whichever operand is plain."""
        tau1 = SO3Tau([1, 2, 3])
        tau2 = SO3Tau([1, 1])

        # Left operand as SO3Tau, tuple, and list in turn.
        for left in (tau1, tuple(tau1), list(tau1)):
            combined = left + tau2
            assert type(combined) is SO3Tau
            assert list(combined) == list(tau1) + list(tau2)

        # In-place addition on a copy.
        tau1p = SO3Tau(tau1)
        tau1p += tau2
        assert type(tau1p) is SO3Tau
        assert list(tau1p) == list(tau1) + list(tau2)

        # sum() starts from 0, so this also requires SO3Tau to handle `0 + tau`.
        combined = sum([SO3Tau([3, 2, 1]), [1], (2, 3)])
        assert type(combined) is SO3Tau
        assert list(combined) == [3, 2, 1, 1, 2, 3]
Esempio n. 14
0
    def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
        """The CG product commutes with rotation: rotate-then-multiply == multiply-then-rotate."""
        largest_l = max(maxl1, maxl2, maxl)
        wigner_d, _, _ = rot.gen_rot(largest_l)

        cg_dict = CGDict(maxl=largest_l, dtype=torch.double)
        cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

        first_tau = SO3Tau([channels] * (maxl1 + 1))
        second_tau = SO3Tau([channels] * (maxl2 + 1))

        first = SO3Vec.randn(first_tau, batch, dtype=torch.double)
        second = SO3Vec.randn(second_tau, batch, dtype=torch.double)

        first_rotated = first.apply_wigner(wigner_d, dir='left')
        second_rotated = second.apply_wigner(wigner_d, dir='left')

        product = cg_prod(first, second)
        product_of_rotated = cg_prod(first_rotated, second_rotated)

        rotated_product = product.apply_wigner(wigner_d, dir='left')

        max_errors = [(lhs - rhs).abs().max()
                      for lhs, rhs in zip(product_of_rotated, rotated_product)]
        assert all(err < 1e-6 for err in max_errors)
Esempio n. 15
0
    def __init__(self,
                 taus_in,
                 tau_out,
                 maxl=None,
                 real=False,
                 weight_init='randn',
                 gain=1,
                 device=None,
                 dtype=None):
        """Concatenate several representations, then linearly mix the result.

        Parameters
        ----------
        taus_in : iterable
            Types of the representations to concatenate (forwarded to CatReps).
        tau_out : SO3Tau, list of int, or int
            Output type of the mixing stage (forwarded to MixReps).
        maxl : int, optional
            Highest weight kept by the concatenation stage.
        real, weight_init, gain, device, dtype
            Forwarded unchanged to MixReps.
        """
        super().__init__(device=device, dtype=dtype)

        # Stage 1: concatenation (registered first, before the mixer).
        self.cat_reps = CatReps(taus_in, maxl=maxl)

        # Stage 2: mixing, whose input type is whatever the concatenation produces.
        mix_kwargs = dict(real=real,
                          weight_init=weight_init,
                          gain=gain,
                          device=device,
                          dtype=dtype)
        self.mix_reps = MixReps(self.cat_reps.tau, tau_out, **mix_kwargs)

        self.taus_in = taus_in
        self.tau_out = SO3Tau(self.mix_reps)
Esempio n. 16
0
 def tau(self):
     """Type with a single entry (weight l = 0) holding ``self.channels_out`` channels."""
     scalar_channels = self.channels_out
     return SO3Tau([scalar_channels])
Esempio n. 17
0
 def test_apply_euler(self, batch, channels, maxl):
     """Smoke test: rotate a random SO3Vec by an Euler-angle Wigner-D matrix."""
     tau = SO3Tau([channels] * (maxl + 1))
     # Pass (tau, batch) in that order, matching SO3Vec.randn(tau, batch, ...)
     # elsewhere. The previous call had the two arguments swapped.
     vec = SO3Vec.rand(tau, batch, dtype=torch.double)
     wigner = SO3WignerD.euler(maxl, dtype=torch.double)
     so3_torch.apply_wigner(wigner, vec)
Esempio n. 18
0
File: agent.py Progetto: gncs/molgym
    def __init__(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        min_max_distance: Tuple[float, float],
        network_width: int,
        maxl: int,
        num_cg_levels: int,
        num_channels_hidden: int,
        num_channels_per_element: int,
        num_gaussians: int,
        bag_scale: int,
        beta: Optional[float] = None,
        device=None,
    ):
        """Build the Cormorant-based network with focus, element, distance and value heads.

        Args:
            observation_space: must expose ``zs`` (the element list; presumably
                atomic numbers — confirm). Also forwarded to the parent class.
            action_space: forwarded to the parent class.
            min_max_distance: ``(min_distance, max_distance)``; min must be < max.
            network_width: hidden width of every MLP head.
            maxl: maximum spherical-harmonic weight; also the CG cutoff.
            num_cg_levels: number of CG levels in the Cormorant backbone.
            num_channels_hidden: channels of each hidden CG level.
            num_channels_per_element: output channels reserved per element of ``zs``.
            num_gaussians: number of components of the distance mixture.
            bag_scale: forwarded to ``Cormorant``.
            beta: stored as ``self.beta``; its use is not visible here.
            device: torch device for the tensors and (via ``self.to``) submodules.
        """
        super().__init__(observation_space, action_space)
        self.device = device
        # Everything below is built in single precision.
        self.dtype = torch.float

        self.zs = self.observation_space.zs
        self.zs_tensor = torch.tensor(self.zs, dtype=self.dtype, device=self.device)

        self.min_distance, self.max_distance = min_max_distance
        # NOTE(review): `assert` is stripped under `python -O`; consider raising ValueError.
        assert self.min_distance < self.max_distance
        self.beta = beta

        self.max_sh = maxl
        self.num_cg_levels = num_cg_levels
        self.num_channels_hidden = num_channels_hidden
        self.num_channels_per_element = num_channels_per_element
        self.num_gaussians = num_gaussians

        # Output channel count (the last entry of Cormorant's num_channels below):
        # one block of num_channels_per_element channels for each element in zs.
        self.num_channels_out = len(self.zs) * self.num_channels_per_element
        # [[0, 1, ..., num_channels_per_element - 1]], shape (1, num_channels_per_element).
        # Presumably added to an element block's start index to select its channels — confirm.
        self.channel_offsets = torch.arange(start=0,
                                            end=self.num_channels_per_element,
                                            dtype=torch.long,
                                            device=self.device).unsqueeze(0)

        # One CG coefficient table, shared by the backbone, the mixer and the spherical harmonics.
        self.cg_dict = CGDict(maxl=self.max_sh, device=self.device, dtype=self.dtype)
        self.cg_model = Cormorant(
            maxl=self.max_sh,  # Cutoff in CG operations (default: [3])
            max_sh=self.max_sh,  # Number of spherical harmonic powers to use (default: [3])
            num_cg_levels=self.num_cg_levels,  # Number of CG levels (default: 4)
            num_channels=[self.num_channels_hidden] * self.num_cg_levels + [self.num_channels_out],
            num_species=len(self.zs),
            cutoff_type=['soft'],  # Types of cutoffs to include
            hard_cut_rad=min(self.max_distance, 2.1),  # Radius of hard cutoff (in AA)
            soft_cut_rad=min(self.max_distance, 2.1),  # Radius of soft cutoff (in AA)
            soft_cut_width=0.2,  # Width of SOFT cutoff in Angstroms (default: 0.2)
            weight_init='rand',  # Weight initialization function to use (default: rand)
            level_gain=[10.0],  # Gain at each level (default: [10.])
            charge_power=2,  # Maximum power to take in one-hot (default: 2)
            basis_set=[3, 3],  # Radial basis size; presumably (trig_basis, rpow) for the radial functions — confirm.
            charge_scale=max(self.zs),
            bag_scale=bag_scale,
            device=self.device,
            dtype=self.dtype,
            cg_dict=self.cg_dict,
        )

        # Mixer over one element's channel block (all weights 0..max_sh) together with an
        # l = 0-only input of the same width (tau_other has a single entry).
        self.cg_mix = CormorantMixer(
            tau_in=SO3Tau([self.num_channels_per_element] * (self.max_sh + 1)),
            tau_other=SO3Tau([self.num_channels_per_element]),
            maxl=self.max_sh,
            num_channels=self.num_channels_per_element,
            level_gain=10.0,
            weight_init='rand',
            device=self.device,
            dtype=self.dtype,
            cg_dict=self.cg_dict,
        )

        self.sph_harms = SphericalHarmonics(maxl=self.max_sh,
                                            conj=False,
                                            sh_norm='qm',
                                            device=self.device,
                                            dtype=self.dtype,
                                            cg_dict=self.cg_dict)

        # Presumably reduces covariant features to invariant scalars — confirm in AtomicScalars.
        self.atomic_scalars = AtomicScalars(maxl=self.max_sh, full_scalars=True, device=self.device, dtype=self.dtype)

        # Latent sizes: for the full output channels, and for a single element's block.
        self.num_latent = self.atomic_scalars.get_output_dim(self.num_channels_out)
        self.num_latent_element = self.atomic_scalars.get_output_dim(self.num_channels_per_element)

        # Focus
        self.phi_focus = MLP(
            input_dim=self.num_latent,
            output_dims=(network_width, 1),
        )

        # Element
        self.phi_element = MLP(
            input_dim=self.num_latent,
            output_dims=(network_width, len(self.zs)),
        )

        # Distance: Gaussian Mixture Model
        # Output has 2 values per mixture component.
        self.phi_d = MLP(
            input_dim=self.num_latent_element,
            output_dims=(network_width, 2 * self.num_gaussians),
        )
        self.pad_zeros = torch.nn.ConstantPad1d(padding=(0, 1), value=0.0)  # Pad with one 0.0 to the right

        # Half-width and midpoint of [min_distance, max_distance]; presumably used to map a
        # bounded network output onto that interval — confirm in the forward pass.
        self.distance_half_width = torch.tensor((self.max_distance - self.min_distance) / 2,
                                                dtype=self.dtype,
                                                device=self.device)
        self.distance_center = torch.tensor((self.min_distance + self.max_distance) / 2,
                                            dtype=self.dtype,
                                            device=self.device)

        # Learnable per-component log standard deviations, initialised to log(0.1).
        self.distance_log_stds = torch.nn.Parameter(torch.log(
            torch.tensor([0.1] * self.num_gaussians, dtype=self.dtype, device=self.device)),
                                                    requires_grad=True)  # (gaussians, )

        # Value function
        self.phi_trans = MLP(
            input_dim=self.num_latent,
            output_dims=(network_width, network_width),
        )
        self.phi_v = MLP(
            input_dim=network_width,
            output_dims=(network_width, 1),
        )

        # The MLPs above were built without a device argument; move all registered
        # parameters/submodules to self.device.
        self.to(self.device)
Esempio n. 19
0
 def tau(self):
     """Empty type: this module carries no SO(3) channels at any weight."""
     no_channels = list()
     return SO3Tau(no_channels)