def test_enhance_atoms():
    """Check the atom table after BCAI.enhance_atoms against the source molecules.

    For every row of the atom table the element symbol, connectivity row and
    distance row are compared with the molecule the row belongs to.
    """
    periodic_table = Get_periodic_table()

    # Build the inputs consumed by enhance_atoms.
    mols = dmy.get_rndethane_mols(distance=True)
    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)

    # Call under test.
    BCAI.enhance_atoms(atoms, structure_dict)

    for row in range(len(atoms['atom_index'].values)):
        molid = atoms['molecule_name'][row]

        # The last molecule with a matching id is used; 0 is the placeholder
        # when nothing matches (the attribute access below then fails).
        matches = [candidate for candidate in mols if candidate.molid == molid]
        mol = matches[-1] if matches else 0

        atid = atoms['atom_index'][row]
        symbol = atoms['typestr'][row]
        assert periodic_table.index(symbol) == mol.types[atid]
        assert np.array_equal(atoms['conn'][row], mol.conn[atid])
        assert np.array_equal(atoms['distance'][row], mol.dist[atid])
def test_get_atomic_ACSF():
    """Smoke test: ACSF atomic features are generated without raising."""
    molecules = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(molecules)
    struc_dict = GNR.make_struc_dict(atom_df)
    bond_df = GNR.make_bonds_df(molecules)

    # No assertion on the representation; only the call itself is exercised.
    atom_df = QML.get_atomic_qml_features(
        atom_df,
        bond_df,
        struc_dict,
        featureflag='ACSF',
    )
def test_make_triplets():
    """Check BCAI.make_triplets against angles recomputed from the molecules.

    For every centre atom ``atom1`` bonded to two distinct atoms ``atom2`` and
    ``atom3`` (with ``atom3 > atom2``) exactly one triplet row must exist, its
    angle must match the angle recomputed from the molecule coordinates, and
    the number of such combinations must equal the number of triplet rows.
    """
    # Build the inputs consumed by make_triplets.
    mols = dmy.get_rndethane_mols(distance=True)
    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)
    bonds = GNR.make_bonds_df(mols)
    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')

    # Call under test.
    triplets = BCAI.make_triplets(bonds["molecule_name"].unique(),
                                  structure_dict)

    assert len(triplets["molecule_name"].unique()) == len(mols)

    count = 0
    for mol in mols:
        natoms = len(mol.types)
        for atom1 in range(natoms):
            for atom2 in range(natoms):
                if atom1 == atom2 or mol.conn[atom1][atom2] != 1:
                    continue
                # Only atom3 > atom2 is considered, so each neighbour pair of
                # the centre atom is counted once.
                for atom3 in range(atom2 + 1, natoms):
                    if atom3 == atom1 or mol.conn[atom1][atom3] != 1:
                        continue

                    row = triplets.loc[
                        (triplets.molecule_name == mol.molid)
                        & (triplets.atom_index_0 == atom1)
                        & (triplets.atom_index_1 == atom2)
                        & (triplets.atom_index_2 == atom3)
                    ]
                    assert len(row.index) == 1

                    ba = mol.xyz[atom2] - mol.xyz[atom1]
                    bc = mol.xyz[atom3] - mol.xyz[atom1]
                    cosine = np.sum(ba * bc) / (np.linalg.norm(ba) *
                                                np.linalg.norm(bc))
                    # Clip guards arccos against rounding just outside [-1, 1].
                    angle = np.arccos(np.clip(cosine, -1.0, 1.0))

                    # Compare with a tolerance: bit-for-bit float equality
                    # depends on the exact operation order used to compute
                    # the stored angle.
                    assert np.isclose(angle, row['angle'].values[0])
                    count += 1

    assert count == len(triplets.index)
def get_features_frommols(self, args, params=None, molcheck_run=False, training=True, max=200):
    """Build the atom/bond feature tables for the molecules held in ``self.mols``.

    Args:
        args: Mapping read for the keys ``'targetflag'`` and ``'featureflag'``.
        params: Optional parameter dict. It is stored on ``self.params`` and
            gets a ``'max'`` entry written into it; ``params['cutoff']`` is
            read for the aSLATM/CMAT/FCHL/ACSF feature flags. A fresh dict is
            used when omitted.
        molcheck_run: When true, return right after ``self.remove_mols``.
        training: Forwarded to ``TFM_features.get_BCAI_features``.
        max: Initial maximum atom count; raised when a molecule in
            ``self.mols`` has more atoms than this.

    Returns:
        ``0`` when the feature flag is not recognised, otherwise ``None``.
    """
    # A literal ``{}`` default would be one dict shared by every call, and it
    # is mutated below (``self.params['max'] = ...``).
    if params is None:
        params = {}
    self.params = params

    target = flag_to_target(args['targetflag'])
    self.remove_mols(target)
    if molcheck_run:
        return

    # ``max`` is the public parameter name (it shadows the builtin), so work
    # on a local copy from here on.
    max_atoms = max
    for mol in self.mols:
        if len(mol.types) > max_atoms:
            max_atoms = len(mol.types) + 1
            print('WARNING, setting max atoms to ', max_atoms)
    self.params['max'] = max_atoms

    self.atoms = GNR.make_atom_df(self.mols)
    self.struc = GNR.make_struc_dict(self.atoms)
    # The bond table is only built for four-character target flags.
    if len(args['targetflag']) == 4:
        self.bonds = GNR.make_bonds_df(self.mols)

    featureflag = args['featureflag']
    if featureflag in ['aSLATM', 'CMAT', 'FCHL', 'ACSF']:
        from autoenrich.ml.features import QML_features
        self.atoms = QML_features.get_atomic_qml_features(
            self.atoms,
            self.bonds,
            self.struc,
            featureflag=featureflag,
            cutoff=params['cutoff'],
            max=max_atoms)
    elif featureflag == 'BCAI':
        from autoenrich.ml.features import TFM_features
        (self.BCAI, self.atoms, self.bonds, self.struc,
         _xfiles, _rfiles, _yfiles) = TFM_features.get_BCAI_features(
            self.atoms,
            self.bonds,
            self.struc,
            targetflag=args['targetflag'],
            training=training)
    elif featureflag == 'dummy':
        # 'dummy' is a known flag that needs no further feature generation.
        return
    else:
        # Any other flag is unknown: report it instead of returning silently.
        print('Feature flag not recognised, no feature flag: ', featureflag)
        return 0
def test_make_struc_dict():
    """Check that GNR.make_struc_dict has one matching entry per molecule.

    Each entry's element symbols, positions and connectivity are compared
    with the molecule it was built from.
    """
    periodic_table = Get_periodic_table()
    molecules = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(molecules)

    struc = GNR.make_struc_dict(atom_df)

    assert len(struc.keys()) == len(molecules)

    for molecule in molecules:
        entry = struc[molecule.molid]
        expected_symbols = [periodic_table[number] for number in molecule.types]

        assert entry['typesstr'] == expected_symbols
        assert np.array_equal(entry['positions'], molecule.xyz)
        assert np.array_equal(entry['conn'], molecule.conn)
def test_get_atomic_FCHL():
    """Check the shape of the first FCHL atomic representation."""
    molecules = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(molecules)
    struc_dict = GNR.make_struc_dict(atom_df)
    bond_df = GNR.make_bonds_df(molecules)

    atom_df = QML.get_atomic_qml_features(
        atom_df,
        bond_df,
        struc_dict,
        featureflag='FCHL',
    )

    first_rep = atom_df['atomic_rep'].values[0]
    assert first_rep.shape == (5, 50)
def test_get_atomic_aSLATM():
    """Check the length of the first aSLATM atomic representation."""
    molecules = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(molecules)
    struc_dict = GNR.make_struc_dict(atom_df)
    bond_df = GNR.make_bonds_df(molecules)

    atom_df = QML.get_atomic_qml_features(
        atom_df,
        bond_df,
        struc_dict,
        featureflag='aSLATM',
    )

    first_rep = atom_df['atomic_rep'].values[0]
    assert len(first_rep) == 20105
def test_enhance_bonds():
    """Check the bond table after BCAI.enhance_bonds against the source molecules.

    Two passes are made:
      1. every bond row is compared with its molecule (coupling length,
         coupling constant) and its ``predict`` flag must be 1 only for rows
         labelled '3JHH';
      2. every ordered atom pair with both types equal to 1 and a coupling
         length of 3 must have exactly one bond row, flagged for prediction
         and carrying the molecule's coupling constant.

    Assumes ``GNR.make_bonds_df`` returns a pandas DataFrame (it is indexed
    with ``.loc`` below).
    """
    # Build the inputs consumed by enhance_bonds.
    mols = dmy.get_rndethane_mols(distance=True)
    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)
    bonds = GNR.make_bonds_df(mols)

    # Call under test.
    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')

    # Walk the row labels. Iterating a DataFrame directly yields its column
    # names, so enumerate(bonds) would only visit as many rows as there are
    # columns.
    for idx in bonds.index:
        molid = bonds['molecule_name'][idx]
        at1 = bonds['atom_index_0'][idx]
        at2 = bonds['atom_index_1'][idx]

        mol = next((m for m in mols if m.molid == molid), None)
        assert mol is not None, 'no molecule found for bond row %r' % (idx,)

        # The first character of the type string is the coupling length.
        assert mol.coupling_len[at1][at2] == int(bonds['type'][idx][0])
        assert mol.coupling[at1][at2] == bonds['scalar_coupling_constant'][idx]

        expected_predict = 1 if bonds['labeled_type'][idx] == '3JHH' else 0
        assert bonds['predict'][idx] == expected_predict

    for mol in mols:
        for atom1, type1 in enumerate(mol.types):
            for atom2, type2 in enumerate(mol.types):
                if atom1 == atom2:
                    continue
                # Only pairs with both types equal to 1 and a coupling
                # length of 3 are asserted on; skip the lookup otherwise.
                if not (type1 == 1 and type2 == 1
                        and mol.coupling_len[atom1][atom2] == 3):
                    continue

                row = bonds.loc[(bonds['molecule_name'] == mol.molid)
                                & (bonds['atom_index_0'] == atom1)
                                & (bonds['atom_index_1'] == atom2)]
                assert len(row.index) == 1
                assert row['predict'].values[0] == 1
                assert (mol.coupling[atom1][atom2]
                        == row['scalar_coupling_constant'].values[0])
def test_get_atomic_cmat():
    """Check size and non-zero count of the first CMAT atomic representation.

    The representation length is compared with 50 * 51 / 2 and its number of
    non-zero entries with n * (n + 1) / 2, where n is the atom count of the
    first molecule.
    """
    molecules = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(molecules)
    struc_dict = GNR.make_struc_dict(atom_df)
    bond_df = GNR.make_bonds_df(molecules)

    atom_df = QML.get_atomic_qml_features(
        atom_df,
        bond_df,
        struc_dict,
        featureflag='CMAT',
    )

    assert len(atom_df['atomic_rep'].values[0]) == 50 * (50 + 1) / 2

    nonzero_positions = atom_df['atomic_rep'].values[0].nonzero()[0]
    assert len(nonzero_positions) == len(molecules[0].types) * (
        len(molecules[0].types) + 1) / 2
def test_get_scaling():
    """Smoke test: BCAI.get_scaling runs on an enhanced bond table."""
    # Build the enhanced bond table that get_scaling consumes.
    molecules = dmy.get_rndethane_mols(distance=True)
    atom_df = GNR.make_atom_df(molecules)
    struc = GNR.make_struc_dict(atom_df)
    BCAI.enhance_structure_dict(struc)
    BCAI.enhance_atoms(atom_df, struc)
    bond_df = GNR.make_bonds_df(molecules)
    BCAI.enhance_bonds(bond_df, struc, flag='3JHH')

    # Call under test; the two returned values are unpacked but not asserted on.
    means, stds = BCAI.get_scaling(bond_df)
def test_enhance_structure_dict():
    """Check the structure dict after BCAI.enhance_structure_dict.

    Each molecule's entry must hold its element symbols, positions,
    connectivity and distances.
    """
    periodic_table = Get_periodic_table()

    # Build the structure dict to be enhanced.
    molecules = dmy.get_rndethane_mols(distance=True)
    atom_df = GNR.make_atom_df(molecules)
    struc = GNR.make_struc_dict(atom_df)

    # Call under test.
    BCAI.enhance_structure_dict(struc)

    for molecule in molecules:
        expected_symbols = [periodic_table[number] for number in molecule.types]

        assert struc[molecule.molid]['typesstr'] == expected_symbols
        assert np.array_equal(struc[molecule.molid]['positions'], molecule.xyz)
        assert np.array_equal(struc[molecule.molid]['conn'], molecule.conn)
        assert np.array_equal(struc[molecule.molid]['distances'], molecule.dist)
def test_add_embedding():
    """Smoke test: BCAI.add_embedding runs on enhanced atoms, bonds and triplets."""
    periodic_table = Get_periodic_table()

    # Build the three tables that add_embedding consumes.
    molecules = dmy.get_rndethane_mols(distance=True)
    atom_df = GNR.make_atom_df(molecules)
    struc = GNR.make_struc_dict(atom_df)
    BCAI.enhance_structure_dict(struc)
    BCAI.enhance_atoms(atom_df, struc)
    bond_df = GNR.make_bonds_df(molecules)
    BCAI.enhance_bonds(bond_df, struc, flag='3JHH')
    molecule_names = bond_df["molecule_name"].unique()
    triplet_df = BCAI.make_triplets(molecule_names, struc)

    # Call under test; the four returned values are unpacked but not asserted on.
    embeddings, atom_df, bond_df, triplet_df = BCAI.add_embedding(
        atom_df, bond_df, triplet_df)
def test_make_atom_df():
    """Check GNR.make_atom_df against the molecules it was built from.

    The table must name as many distinct molecules as were passed in, every
    row's connectivity and x/y/z must match its molecule, and the number of
    matched rows must equal the total atom count.
    """
    molecules = dmy.get_rndethane_mols()
    expected_atoms = sum(len(molecule.types) for molecule in molecules)

    atom_df = GNR.make_atom_df(molecules)

    assert len(atom_df["molecule_name"].unique()) == len(molecules)

    matched = 0
    for row, atom_idx in enumerate(atom_df['atom_index'].values):
        for molecule in molecules:
            if atom_df['molecule_name'][row] != molecule.molid:
                continue
            matched += 1
            assert np.array_equal(molecule.conn[atom_idx], atom_df['conn'][row])
            assert molecule.xyz[atom_idx][0] == atom_df['x'][row]
            assert molecule.xyz[atom_idx][1] == atom_df['y'][row]
            assert molecule.xyz[atom_idx][2] == atom_df['z'][row]

    assert matched == expected_atoms