def __init__(self, elements, params, device, frozen=False):
    """
    Base class for the density functions.

    Converts the per-element parameter mapping into one tensor per parameter
    name, where position ``i`` of each tensor holds the value for ``elements[i]``
    (0.0 for elements that do not define that parameter).

    :param elements: sequence of element identifiers; each is passed to
        ``element2an`` and ``elements.index`` is used to locate it.
    :param params: dict mapping an element identifier (must appear in
        ``elements``) to a dict of ``{parameter_name: value}``.
    :param device: torch device on which the tensors are allocated.
    :param frozen: must be a bool; stored as ``self.frozen``.
    :raises IOError: if ``frozen`` is not a bool.
    """
    self.device = device
    atomic_numbers = [element2an(symbol) for symbol in elements]
    self.elements = convert_label2tensor(atomic_numbers, device=device)
    self.raw_params = params

    n_elements = len(elements)
    table = dict()
    self.params = table
    for symbol, per_element in params.items():
        row = elements.index(symbol)  # ValueError if symbol is not listed in elements
        for name, val in per_element.items():
            if name not in table:
                # First time this parameter is seen: one zero slot per element.
                table[name] = torch.zeros(n_elements, dtype=torch.float, device=device)
            table[name][row] = val

    if not isinstance(frozen, bool):
        raise IOError("frozen must be bolean")
    self.frozen = frozen
def parametrize(self, mol, params):
    """
    Fill in the 'coefficient' of each density function in ``params`` using ``self.model``.

    For every entry, ``match_fun_names`` gives a ``(funtype, pos)`` pair.
    - 'core' entries subtract their integral from the atomic charge of their center.
    - 'iso' entries store their integral in a per-atom table.
    - 'aniso' entries store their integral in a per-bond table.
    Both tables are passed to ``self.model.forward_coefficients``. The predicted
    coefficients are then written back into the 'iso' and 'aniso' entries.

    :param mol: a ``Mol2`` instance, or anything accepted by ``Mol2(file=...)``
        (e.g. a file path).
    :param params: re-iterable collection (it is traversed twice) of dicts, one per
        density function. Each dict needs a 'center' key (atom index) plus whatever
        ``match_fun_names``, ``integrate_from_dict`` and ``self.match_bond`` read.
    :return: the same ``params`` object. Its 'iso' and 'aniso' dicts are updated
        in place with a float 'coefficient'.
    :raises IOError: if ``mol`` is not a ``Mol2`` and ``Mol2(file=mol)`` raises ValueError.
    """
    if not isinstance(mol, Mol2):
        try:
            mol = Mol2(file=mol)
        except ValueError as err:
            raise IOError("unknown format for mol. Please, use Mol2 or str") from err

    device = self.device
    coords = torch.tensor(mol.get_coordinates(), device=device, dtype=torch.float).unsqueeze(0)
    labels = convert_label2tensor(mol.get_atomic_numbers(), device=device).unsqueeze(0)
    connectivity = torch.tensor(mol.get_bonds(), dtype=torch.long, device=device).unsqueeze(0)
    # NOTE(review): if mol.charges / mol.atomic_numbers are plain lists, `+` concatenates
    # instead of adding element-wise; presumably they are numpy arrays — confirm.
    charge = torch.tensor(mol.charges + mol.atomic_numbers, dtype=torch.float, device=device).unsqueeze(0)
    natoms = mol.get_number_atoms()
    nbonds = mol.get_number_bonds()
    # Shift bond atom indices down by one (presumably Mol2 bonds are 1-based — confirm).
    connectivity -= 1

    # Integrals are filled on the CPU element by element, then moved to the device once.
    int_iso = torch.zeros(1, natoms, 2)
    int_aniso = torch.zeros(1, nbonds, 4)
    for fun in params:
        center = fun['center']
        funtype, pos = match_fun_names(fun)
        # Exact comparisons: the original `funtype in 'iso'` was a substring test.
        if funtype == 'core':
            charge[0, center] -= integrate_from_dict(fun)
        elif funtype == 'iso':
            int_iso[0, center, pos] = integrate_from_dict(fun)
        elif funtype == 'aniso':
            idx, col = self.match_bond(connectivity, fun)
            int_aniso[0, idx, col + pos] = integrate_from_dict(fun)
    int_iso = int_iso.to(device)
    int_aniso = int_aniso.to(device)

    _, _, iso_out, aniso_out = self.model.forward_coefficients(
        labels, connectivity, coords, charge, int_iso, int_aniso)
    # detach() drops the autograd graph; cpu() is a no-op for tensors already on the CPU,
    # so this works for any device without branching on is_cuda.
    iso_out = iso_out.squeeze(0).detach().cpu().numpy()
    aniso_out = aniso_out.squeeze(0).detach().cpu().numpy()

    for fun in params:
        center = fun['center']
        funtype, pos = match_fun_names(fun)
        if funtype == 'iso':
            fun['coefficient'] = iso_out[center, pos].item()
        elif funtype == 'aniso':
            idx, col = self.match_bond(connectivity, fun)
            fun['coefficient'] = aniso_out[idx, col + pos].item()
        # 'core' entries (and any unrecognised type) receive no coefficient.
    return params