Example #1
0
    def protodensity(self, coordinates: torch.Tensor, labels: torch.Tensor,
                     centers: torch.Tensor):
        """
        Evaluates the protomolecule density at the sample coordinates.

        One kernel term is built per column of the model parameters
        ``self.p``, ``self.b`` and ``self.a`` (each split along dim 1).
        The terms are added together without any coefficient and the
        accumulated tensor is reduced over dimension 2.

        :param coordinates: sample coordinates (``sample_coords`` of
            ``distance_vectors``)
        :param labels: labels, forwarded to ``distance_vectors`` and used by
            ``expand_parameter`` to expand every parameter column
        :param centers: coordinates of the centers (``mol_coords`` of
            ``distance_vectors``)
        :return: accumulated kernel values summed over dimension 2
        """
        displacements = distance_vectors(sample_coords=coordinates,
                                         mol_coords=centers,
                                         labels=labels,
                                         device=self.device)
        radii = distance(displacements)
        total = torch.zeros_like(radii)

        # Size-1 slices along dim 1: one (p, b, a) triple per kernel term.
        parameter_columns = zip(self.p.split(1, dim=1),
                                self.b.split(1, dim=1),
                                self.a.split(1, dim=1))

        for p_col, b_col, a_col in parameter_columns:
            p_expanded = expand_parameter(labels, p_col)
            a_expanded = expand_parameter(labels, a_col)
            b_expanded = expand_parameter(labels, b_col)
            # In-place accumulation keeps the shape fixed to that of `radii`.
            total += gen_exponential_kernel(radii, a_expanded, b_expanded,
                                            p_expanded)

        return total.sum(dim=2)
Example #2
0
    def forward(self, coordinates: torch.Tensor, coefficients: torch.Tensor,
                labels: torch.Tensor, centers: torch.Tensor):
        """
        Evaluates the density at the sample coordinates.

        For every column of the model parameters ``self.p``, ``self.b`` and
        ``self.a`` (split along dim 1) a kernel term is built and weighted
        by the matching size-1 slice of ``coefficients`` (split along
        dim 2). The weighted terms are accumulated and the result is
        reduced over dimension 2.

        :param coordinates: sample coordinates (``sample_coords`` of
            ``distance_vectors``)
        :param coefficients: coefficients; slice ``i`` along dim 2 weights
            kernel term ``i``
        :param labels: labels, forwarded to ``distance_vectors`` and used by
            ``expand_parameter`` to expand every parameter column
        :param centers: coordinates of the centers (``mol_coords`` of
            ``distance_vectors``)
        :return: weighted kernel values summed over dimension 2
        """
        displacements = distance_vectors(sample_coords=coordinates,
                                         mol_coords=centers,
                                         labels=labels,
                                         device=self.device)
        radii = distance(displacements)
        total = torch.zeros_like(radii)

        # zip stops at the shortest sequence, so the number of terms is the
        # smallest of the coefficient / parameter column counts.
        terms = zip(coefficients.split(1, dim=2),
                    self.p.split(1, dim=1),
                    self.b.split(1, dim=1),
                    self.a.split(1, dim=1))

        for coeff_slice, p_col, b_col, a_col in terms:
            p_expanded = expand_parameter(labels, p_col)
            a_expanded = expand_parameter(labels, a_col)
            b_expanded = expand_parameter(labels, b_col)
            kernel = gen_exponential_kernel(radii, a_expanded, b_expanded,
                                            p_expanded)
            # Swap dims 1 and 2 of the coefficient slice so it broadcasts
            # against the kernel.
            weights = coeff_slice.transpose(1, 2)
            total += weights * kernel

        return total.sum(dim=2)
Example #3
0
    def forward(self, l, t, r, c_i, c_a, x):
        """
        Evaluates the model at the sample coordinates by accumulating the
        contribution of every function in ``self.functions``.

        Frozen ``A2MDtIso`` functions are added without coefficients.
        Non-frozen ``A2MDtIso`` functions are weighted by the slice of
        ``c_i`` selected by their position. ``A2MDtAniso`` functions are
        evaluated twice (forward and reversed topology) and weighted by the
        first and second half of ``c_a`` respectively.

        :param l: labels
        :param t: topology
        :param r: molecular coordinates (angstroms)
        :param c_i: isotropic coefficients
        :param c_a: anisotropic coefficients
        :param x: sample coordinates (atomic units)
        :return: tensor of shape ``(x.size(0), x.size(1))`` holding the
            accumulated contributions
        :raises NotImplementedError: if a frozen function is not an
            ``A2MDtIso``
        """

        # Bring the molecular coordinates to the same units as x.
        r = r * UNITS_TABLE['angstrom']['au']

        # Size-1 slices along dim 2, indexed below by the function position.
        c_i = torch.split(c_i, 1, 2)
        # First half along dim 2 -> forward direction, second half -> reverse.
        c_a = torch.chunk(c_a, 2, 2)
        c_a_f = torch.split(c_a[0], 1, 2)
        c_a_b = torch.split(c_a[1], 1, 2)

        dv = distance_vectors(x, r, l, self.device)
        d = distance(dv)
        z1 = angle(dv, d, x, r, t, self.device)
        # The reverse direction uses the topology flipped along dim 2.
        t_flipped = t.flip(2)
        z2 = angle(dv, d, x, r, t_flipped, self.device)

        p = torch.zeros((x.size(0), x.size(1)),
                        dtype=torch.float,
                        device=self.device)

        # Selected distances/labels depend only on (d, t) and (l, t), which do
        # not change inside the loop, so they are computed once, lazily, on
        # the first anisotropic function instead of once per function.
        # NOTE(review): assumes select_distances / select_labels are pure and
        # that fun.forward does not modify d or l in place - confirm.
        selected = None

        for fun, pos in zip(self.functions, self.positions):

            if fun.frozen:
                # Exact type check kept on purpose: isinstance could change
                # dispatch if the function classes are related by inheritance.
                if type(fun) is not A2MDtIso:
                    raise NotImplementedError(
                        "only frozen isotropic functions are considered")
                p += fun.forward(l, d).sum(2)
                continue

            if type(fun) is A2MDtIso:

                c_i_ = c_i[pos].transpose(1, 2)
                p += (c_i_ * fun.forward(l, d)).sum(2)

            elif type(fun) is A2MDtAniso:

                c_a_forward = c_a_f[pos].transpose(1, 2)
                c_a_reverse = c_a_b[pos].transpose(1, 2)

                if selected is None:
                    d_forward, d_reverse = select_distances(
                        d, t, device=self.device)
                    l_forward, l_reverse = select_labels(
                        l, t, device=self.device)
                    selected = (d_forward, d_reverse, l_forward, l_reverse)
                d_forward, d_reverse, l_forward, l_reverse = selected

                p += (c_a_forward * fun.forward(l_forward,
                                                d_forward, z1)).sum(2)
                p += (c_a_reverse * fun.forward(l_reverse,
                                                d_reverse, z2)).sum(2)

            # NOTE(review): non-frozen functions of any other type are
            # silently skipped - confirm this is intended.

        return p
Example #4
0
    def forward_core(self, l, r, x):
        """
        Evaluates only the frozen ``A2MDtIso`` functions of
        ``self.functions`` at the sample coordinates.

        Functions that are not frozen, or frozen functions of any other
        type, are skipped without raising.

        :param l: labels
        :param r: molecular coordinates, used as given (no unit conversion
            is applied in this method)
        :param x: sample coordinates
        :return: tensor of shape ``(x.size(0), x.size(1))`` holding the
            accumulated contributions
        """
        displacements = distance_vectors(x, r, l, self.device)
        radii = distance(displacements)
        accumulated = torch.zeros((x.size(0), x.size(1)),
                                  dtype=torch.float,
                                  device=self.device)

        for fun in self.functions:
            if not fun.frozen:
                continue
            # Exact type match: only plain A2MDtIso instances contribute.
            if type(fun) is not A2MDtIso:
                continue
            accumulated += fun.forward(l, radii).sum(2)

        return accumulated