def protodensity(self, coordinates: torch.Tensor, labels: torch.Tensor, centers: torch.Tensor):
    """
    Evaluates the protomolecule density at the sample coordinates.

    ``self.p``, ``self.b`` and ``self.a`` are each split into size-1 slices
    along dim 1. For every slice triple the parameters are expanded with the
    labels, an exponential kernel is evaluated on the sample-to-centre
    distances, and the kernels are accumulated. The accumulated tensor is
    finally summed over dim 2.

    :param coordinates: sample coordinates (forwarded as ``sample_coords``)
    :param labels: labels used to expand every parameter slice
    :param centers: centre coordinates (forwarded as ``mol_coords``)
    :return: accumulated kernel values summed over dim 2
    """
    displacement = distance_vectors(
        sample_coords=coordinates,
        mol_coords=centers,
        labels=labels,
        device=self.device,
    )
    radii = distance(displacement)
    total = torch.zeros_like(radii)

    parameter_slices = zip(
        self.p.split(1, dim=1),
        self.b.split(1, dim=1),
        self.a.split(1, dim=1),
    )
    for p_slice, b_slice, a_slice in parameter_slices:
        p_expanded = expand_parameter(labels, p_slice)
        a_expanded = expand_parameter(labels, a_slice)
        b_expanded = expand_parameter(labels, b_slice)
        # In-place add: the running total keeps the shape of the distances.
        total += gen_exponential_kernel(radii, a_expanded, b_expanded, p_expanded)

    return total.sum(2)
def forward(self, coordinates: torch.Tensor, coefficients: torch.Tensor, labels: torch.Tensor,
            centers: torch.Tensor):
    """
    Evaluates the density at the sample coordinates.

    The coefficients are split into size-1 slices along dim 2 and paired, in
    order, with the size-1 slices of ``self.p``, ``self.b`` and ``self.a``
    taken along dim 1. Each coefficient slice (with dims 1 and 2 swapped)
    weights the exponential kernel built from the matching parameter slices.
    The weighted kernels are accumulated and summed over dim 2.

    :param coordinates: sample coordinates (forwarded as ``sample_coords``)
    :param coefficients: per-kernel weights, split along dim 2
    :param labels: labels used to expand every parameter slice
    :param centers: centre coordinates (forwarded as ``mol_coords``)
    :return: weighted kernel values summed over dim 2
    """
    displacement = distance_vectors(
        sample_coords=coordinates,
        mol_coords=centers,
        labels=labels,
        device=self.device,
    )
    radii = distance(displacement)
    total = torch.zeros_like(radii)

    # zip stops at the shortest sequence, pairing the n-th coefficient slice
    # with the n-th slice of every parameter table.
    slices = zip(
        coefficients.split(1, dim=2),
        self.p.split(1, dim=1),
        self.b.split(1, dim=1),
        self.a.split(1, dim=1),
    )
    for coeff_slice, p_slice, b_slice, a_slice in slices:
        p_expanded = expand_parameter(labels, p_slice)
        a_expanded = expand_parameter(labels, a_slice)
        b_expanded = expand_parameter(labels, b_slice)
        kernel = gen_exponential_kernel(radii, a_expanded, b_expanded, p_expanded)
        weights = coeff_slice.transpose(1, 2)
        # In-place add: the running total keeps the shape of the distances.
        total += weights * kernel

    return total.sum(2)
def forward(self, l, t, r, c_i, c_a, x):
    """
    Evaluates the density at the sample coordinates by accumulating the
    contribution of every function in ``self.functions``.

    Each function is paired with an entry of ``self.positions`` that selects
    which coefficient slice weights it:

    * frozen ``A2MDtIso`` functions are added without coefficients;
    * non-frozen ``A2MDtIso`` functions are weighted by the isotropic slice;
    * non-frozen ``A2MDtAniso`` functions are evaluated twice, once with the
      topology as given and once with the topology flipped along dim 2, each
      weighted by its own half of the anisotropic coefficients.

    Functions of any other type that are not frozen contribute nothing.

    :param l: labels
    :param t: topology
    :param r: molecular coordinates (angstroms), converted to atomic units here
    :param c_i: isotropic coefficients, split into size-1 slices along dim 2
    :param c_a: anisotropic coefficients, chunked into two halves along dim 2
        (first half: forward direction, second half: reverse direction)
    :param x: sample coordinates (atomic units)
    :return: tensor of shape ``(x.size(0), x.size(1))`` with the summed density
    :raises NotImplementedError: if a frozen function is not an ``A2MDtIso``
    """
    r = r * UNITS_TABLE['angstrom']['au']

    c_i = torch.split(c_i, 1, 2)
    c_a = torch.chunk(c_a, 2, 2)
    c_a_f = torch.split(c_a[0], 1, 2)
    c_a_b = torch.split(c_a[1], 1, 2)

    dv = distance_vectors(x, r, l, self.device)
    d = distance(dv)
    z1 = angle(dv, d, x, r, t, self.device)
    t_flipped = t.flip(2)
    z2 = angle(dv, d, x, r, t_flipped, self.device)

    p = torch.zeros((x.size(0), x.size(1)), dtype=torch.float, device=self.device)

    # The topology-based selections depend only on d, l and t, none of which
    # change inside the loop. They are computed on first use and then reused,
    # so nothing is computed when there is no anisotropic function.
    selection = None

    for fun, pos in zip(self.functions, self.positions):
        # Exact type checks are kept on purpose: the class hierarchy of
        # A2MDtIso / A2MDtAniso is not visible from this block, so isinstance
        # could change which branch a subclass takes.
        if fun.frozen:
            if type(fun) is not A2MDtIso:
                raise NotImplementedError(
                    "only frozen isotropic functions are considered")
            p += fun.forward(l, d).sum(2)
            continue

        if type(fun) is A2MDtIso:
            c_i_ = c_i[pos].transpose(1, 2)
            p += (c_i_ * fun.forward(l, d)).sum(2)
        elif type(fun) is A2MDtAniso:
            c_a_forward = c_a_f[pos].transpose(1, 2)
            c_a_reverse = c_a_b[pos].transpose(1, 2)
            if selection is None:
                d_selected_forward, d_selected_reverse = select_distances(
                    d, t, device=self.device)
                l_selected_forward, l_selected_backwards = select_labels(
                    l, t, device=self.device)
                selection = (
                    d_selected_forward,
                    d_selected_reverse,
                    l_selected_forward,
                    l_selected_backwards,
                )
            (d_selected_forward, d_selected_reverse,
             l_selected_forward, l_selected_backwards) = selection
            p += (c_a_forward * fun.forward(
                l_selected_forward, d_selected_forward, z1)).sum(2)
            p += (c_a_reverse * fun.forward(
                l_selected_backwards, d_selected_reverse, z2)).sum(2)

    return p
def forward_core(self, l, r, x):
    """
    Accumulates the contribution of the frozen isotropic functions only.

    Every entry of ``self.functions`` that is frozen and is exactly an
    ``A2MDtIso`` is evaluated on the sample-to-centre distances and summed
    over dim 2. All other functions are skipped silently.

    NOTE(review): unlike the sibling ``forward(l, t, r, c_i, c_a, x)``, ``r``
    is used as given, with no angstrom-to-au conversion. Confirm that callers
    pass coordinates in the intended units.

    :param l: labels
    :param r: molecular coordinates
    :param x: sample coordinates
    :return: tensor of shape ``(x.size(0), x.size(1))`` with the summed density
    """
    displacement = distance_vectors(x, r, l, self.device)
    radii = distance(displacement)

    n_batch, n_samples = x.size(0), x.size(1)
    core = torch.zeros((n_batch, n_samples), dtype=torch.float, device=self.device)

    for candidate in self.functions:
        is_frozen_iso = candidate.frozen and type(candidate) is A2MDtIso
        if not is_frozen_iso:
            continue
        core += candidate.forward(l, radii).sum(2)

    return core