コード例 #1
0
def test_enhance_atoms():
    """Check that BCAI.enhance_atoms adds per-atom type, conn and distance.

    After enhancement, every row of ``atoms`` must agree with its source
    molecule: ``typestr`` maps through the periodic table to the entry in
    ``mol.types``, and ``conn`` / ``distance`` equal that atom's rows in
    ``mol.conn`` / ``mol.dist``.
    """
    p_table = Get_periodic_table()

    #####
    mols = dmy.get_rndethane_mols(distance=True)
    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    ###########

    BCAI.enhance_atoms(atoms, structure_dict)

    # Index molecules once instead of scanning the list for every atom. As with
    # the original scan, a duplicated molid resolves to the last such molecule.
    mols_by_id = {mol.molid: mol for mol in mols}

    for i in range(len(atoms.index)):
        molid = atoms['molecule_name'][i]
        # Fail with a readable message instead of an AttributeError on a
        # placeholder value when no molecule matches.
        assert molid in mols_by_id, 'no molecule with id %s' % (molid,)
        mol = mols_by_id[molid]

        atid = atoms['atom_index'][i]
        assert p_table.index(atoms['typestr'][i]) == mol.types[atid]
        assert np.array_equal(atoms['conn'][i], mol.conn[atid])
        assert np.array_equal(atoms['distance'][i], mol.dist[atid])
コード例 #2
0
def test_get_atomic_ACSF():
    """Compute ACSF atomic features for random ethane molecules.

    Previously this only checked that the call did not raise. It now also
    checks that an ``atomic_rep`` column is produced, which is where the
    FCHL / aSLATM / CMAT feature tests read their representations from.
    """
    mols = dmy.get_rndethane_mols()
    atoms = GNR.make_atom_df(mols)
    struc = GNR.make_struc_dict(atoms)
    bonds = GNR.make_bonds_df(mols)

    atoms = QML.get_atomic_qml_features(atoms,
                                        bonds,
                                        struc,
                                        featureflag='ACSF')

    # NOTE(review): assumes the ACSF path stores its output in the same
    # 'atomic_rep' column as the other QML flags - confirm in QML_features.
    assert 'atomic_rep' in atoms.columns
コード例 #3
0
def test_make_triplets():
    """Compare BCAI.make_triplets against a brute-force angle enumeration.

    For each molecule, every centre atom ``atom1`` with ``conn == 1`` to two
    distinct atoms ``atom2 < atom3`` must appear exactly once in the triplet
    table. Its ``angle`` must match the angle atom2-atom1-atom3 in radians
    (from ``np.arccos``). The table must hold no other rows.
    """
    #####
    mols = dmy.get_rndethane_mols(distance=True)

    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)

    bonds = GNR.make_bonds_df(mols)
    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')
    ############

    triplets = BCAI.make_triplets(bonds["molecule_name"].unique(),
                                  structure_dict)
    assert len(triplets["molecule_name"].unique()) == len(mols)

    count = 0
    for mol in mols:
        n_atoms = len(mol.types)
        for atom1 in range(n_atoms):
            for atom2 in range(n_atoms):
                if atom1 == atom2 or mol.conn[atom1][atom2] != 1:
                    continue

                # Only atom3 > atom2 so each unordered neighbour pair is
                # counted once.
                for atom3 in range(atom2 + 1, n_atoms):
                    if atom3 == atom1 or mol.conn[atom1][atom3] != 1:
                        continue

                    row = triplets.loc[(triplets.molecule_name == mol.molid)
                                       & (triplets.atom_index_0 == atom1)
                                       & (triplets.atom_index_1 == atom2)
                                       & (triplets.atom_index_2 == atom3)]

                    assert len(row.index) == 1

                    ba = mol.xyz[atom2] - mol.xyz[atom1]
                    bc = mol.xyz[atom3] - mol.xyz[atom1]

                    cos_angle = np.sum(ba * bc) / (np.linalg.norm(ba) *
                                                   np.linalg.norm(bc))
                    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))

                    # The implementation may evaluate the angle with a
                    # different floating-point operation order, so exact
                    # equality is brittle; compare with a tolerance.
                    assert np.isclose(row.angle.values[0], angle)
                    count += 1

    assert count == len(triplets.index)
コード例 #4
0
    def get_features_frommols(self,
                              args,
                              params=None,
                              molcheck_run=False,
                              training=True,
                              max=200):
        """Build atom/structure/bond tables for ``self.mols`` and featurise them.

        Args:
            args: mapping with at least 'targetflag' and 'featureflag' keys.
            params: feature parameters. Stored as ``self.params`` (by reference)
                and updated in place with the final 'max' value. Must contain
                'cutoff' when a QML feature flag is used. A fresh dict is
                created when omitted.
            molcheck_run: if True, only call ``remove_mols`` for the target and
                return before any feature generation.
            training: forwarded to ``TFM_features.get_BCAI_features``.
            max: initial maximum atom count. It is raised to
                ``len(mol.types) + 1`` for any molecule with more atoms. The
                name shadows the builtin but is kept for keyword callers.

        Returns:
            0 when ``featureflag`` is not recognised, otherwise None.
        """
        # A ``{}`` default would be a single dict shared by every call, and it
        # is mutated below (``self.params['max'] = max``), so create a fresh
        # one per call. A caller-supplied dict is still stored and updated.
        if params is None:
            params = {}
        self.params = params

        target = flag_to_target(args['targetflag'])
        self.remove_mols(target)
        if molcheck_run:
            return

        for mol in self.mols:
            if len(mol.types) > max:
                max = len(mol.types) + 1
                print('WARNING, setting max atoms to ', max)

        self.params['max'] = max

        self.atoms = GNR.make_atom_df(self.mols)
        self.struc = GNR.make_struc_dict(self.atoms)
        if len(args['targetflag']) == 4:
            self.bonds = GNR.make_bonds_df(self.mols)

        featureflag = args['featureflag']
        if featureflag in ['aSLATM', 'CMAT', 'FCHL', 'ACSF']:
            from autoenrich.ml.features import QML_features
            # NOTE(review): self.bonds is only assigned above for 4-character
            # target flags; otherwise this relies on a pre-existing attribute.
            self.atoms = QML_features.get_atomic_qml_features(
                self.atoms,
                self.bonds,
                self.struc,
                featureflag=featureflag,
                cutoff=params['cutoff'],
                max=max)

        elif featureflag == 'BCAI':
            from autoenrich.ml.features import TFM_features
            self.BCAI, self.atoms, self.bonds, self.struc, xfiles, rfiles, yfiles = TFM_features.get_BCAI_features(
                self.atoms,
                self.bonds,
                self.struc,
                targetflag=args['targetflag'],
                training=training)

        elif featureflag == 'dummy':
            # 'dummy': keep the tables built above, compute no features.
            return

        else:
            # This test was previously inverted ('!= dummy'): unknown flags
            # returned silently while 'dummy' was reported as unrecognised.
            print('Feature flag not recognised, no feature flag: ',
                  featureflag)
            return 0
コード例 #5
0
def test_make_struc_dict():
    """make_struc_dict yields one entry per molecule.

    Each entry holds that molecule's element symbols ('typesstr'),
    coordinates ('positions') and connectivity ('conn').
    """
    p_table = Get_periodic_table()

    mols = dmy.get_rndethane_mols()
    structure_dict = GNR.make_struc_dict(GNR.make_atom_df(mols))
    assert len(structure_dict.keys()) == len(mols)

    for mol in mols:
        entry = structure_dict[mol.molid]
        expected_symbols = [p_table[atomic_num] for atomic_num in mol.types]
        assert entry['typesstr'] == expected_symbols
        assert np.array_equal(entry['positions'], mol.xyz)
        assert np.array_equal(entry['conn'], mol.conn)
コード例 #6
0
def test_get_atomic_FCHL():
    """The FCHL representation stored for the first atom row has shape (5, 50)."""
    mols = dmy.get_rndethane_mols()
    atoms = GNR.make_atom_df(mols)
    struc = GNR.make_struc_dict(atoms)
    bonds = GNR.make_bonds_df(mols)

    featurised = QML.get_atomic_qml_features(
        atoms, bonds, struc, featureflag='FCHL')

    first_rep = featurised['atomic_rep'].values[0]
    assert first_rep.shape == (5, 50)
コード例 #7
0
def test_get_atomic_aSLATM():
    """The aSLATM representation stored for the first atom row has 20105 entries."""
    mols = dmy.get_rndethane_mols()
    atoms = GNR.make_atom_df(mols)
    struc = GNR.make_struc_dict(atoms)
    bonds = GNR.make_bonds_df(mols)

    featurised = QML.get_atomic_qml_features(
        atoms, bonds, struc, featureflag='aSLATM')

    first_rep = featurised['atomic_rep'].values[0]
    assert len(first_rep) == 20105
コード例 #8
0
def test_enhance_bonds():
    """Check the per-bond columns added by BCAI.enhance_bonds.

    For every bond row:
    - the first character of 'type' equals ``mol.coupling_len``;
    - 'scalar_coupling_constant' equals ``mol.coupling``;
    - 'predict' is 1 exactly when 'labeled_type' is '3JHH', else 0.

    Then, for every atom pair with both types == 1 and coupling_len == 3, the
    matching bond row must have predict == 1 and carry the molecule's coupling.
    """
    #####
    mols = dmy.get_rndethane_mols(distance=True)

    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)

    bonds = GNR.make_bonds_df(mols)
    ############

    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')

    # Like the original linear scan, a duplicated molid resolves to the last
    # molecule with that id.
    mols_by_id = {mol.molid: mol for mol in mols}

    # Iterating a DataFrame directly yields its column labels, not its rows,
    # which previously limited this check to the first len(columns) bonds.
    # Walk the row labels so every bond is checked.
    for idx in bonds.index:
        molid = bonds['molecule_name'][idx]
        at1 = bonds['atom_index_0'][idx]
        at2 = bonds['atom_index_1'][idx]

        assert molid in mols_by_id, 'no molecule with id %s' % (molid,)
        mol = mols_by_id[molid]

        assert mol.coupling_len[at1][at2] == int(bonds['type'][idx][0])
        assert mol.coupling[at1][at2] == bonds['scalar_coupling_constant'][idx]
        if bonds['labeled_type'][idx] == '3JHH':
            assert bonds['predict'][idx] == 1
        else:
            assert bonds['predict'][idx] == 0

    for mol in mols:
        for atom1, type1 in enumerate(mol.types):
            for atom2, type2 in enumerate(mol.types):

                if atom1 == atom2:
                    continue

                row = bonds.loc[(bonds['molecule_name'] == mol.molid)
                                & (bonds['atom_index_0'] == atom1)
                                & (bonds['atom_index_1'] == atom2)]

                cpl = row['scalar_coupling_constant'].values

                if type1 == 1 and type2 == 1 and mol.coupling_len[atom1][
                        atom2] == 3:
                    # Exactly one row must exist; an empty match would
                    # otherwise surface as an IndexError or an ambiguous
                    # empty-array truth test.
                    assert len(row.index) == 1
                    assert row.predict.values[0] == 1
                    assert mol.coupling[atom1][atom2] == cpl[0]
コード例 #9
0
def test_get_atomic_cmat():
    """Check the length and nonzero count of the first CMAT representation.

    The first atom row's representation must have 50*51/2 entries. Its
    nonzero count must be N*(N+1)/2, where N is the atom count of the first
    molecule.
    """
    mols = dmy.get_rndethane_mols()
    atoms = GNR.make_atom_df(mols)
    struc = GNR.make_struc_dict(atoms)
    bonds = GNR.make_bonds_df(mols)

    featurised = QML.get_atomic_qml_features(
        atoms, bonds, struc, featureflag='CMAT')

    first_rep = featurised['atomic_rep'].values[0]

    # NOTE(review): 50 looks like the padded atom count and N*(N+1)/2 like a
    # packed triangular matrix - confirm against QML_features. Also assumes
    # the first atom row belongs to mols[0].
    padded_size = 50
    n_atoms = len(mols[0].types)

    # Both products are even, so integer division is exact.
    assert len(first_rep) == padded_size * (padded_size + 1) // 2
    assert len(first_rep.nonzero()[0]) == n_atoms * (n_atoms + 1) // 2
コード例 #10
0
def test_get_scaling():
    """Smoke test: BCAI.get_scaling runs on a fully enhanced bond table.

    Only checks that the call succeeds and that its result unpacks into two
    values; the values themselves are not inspected.
    """
    mols = dmy.get_rndethane_mols(distance=True)

    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)

    bonds = GNR.make_bonds_df(mols)
    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')

    _means, _stds = BCAI.get_scaling(bonds)
コード例 #11
0
def test_enhance_structure_dict():
    """Check entries after BCAI.enhance_structure_dict.

    Each entry must still match its molecule's symbols, positions and
    connectivity, and must carry 'distances' equal to ``mol.dist``.
    """
    p_table = Get_periodic_table()

    mols = dmy.get_rndethane_mols(distance=True)
    structure_dict = GNR.make_struc_dict(GNR.make_atom_df(mols))

    BCAI.enhance_structure_dict(structure_dict)

    for mol in mols:
        entry = structure_dict[mol.molid]
        assert entry['typesstr'] == [p_table[atomic_num]
                                     for atomic_num in mol.types]

        array_fields = (('positions', mol.xyz),
                        ('conn', mol.conn),
                        ('distances', mol.dist))
        for key, expected in array_fields:
            assert np.array_equal(entry[key], expected)
コード例 #12
0
def test_add_embedding():
    """Smoke test: BCAI.add_embedding runs on enhanced atoms, bonds and triplets.

    Only checks that the call succeeds and that its result unpacks into four
    values; nothing about the embeddings is asserted.
    """
    Get_periodic_table()  # result unused by this test

    mols = dmy.get_rndethane_mols(distance=True)

    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)

    bonds = GNR.make_bonds_df(mols)
    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')

    molecule_names = bonds["molecule_name"].unique()
    triplets = BCAI.make_triplets(molecule_names, structure_dict)

    _embeddings, atoms, bonds, triplets = BCAI.add_embedding(
        atoms, bonds, triplets)
コード例 #13
0
def test_make_atom_df():
    """Check the per-atom rows built by make_atom_df.

    Every atom row must match its molecule's connectivity row and x/y/z
    coordinates. The number of (row, molecule) matches must equal the total
    atom count across all molecules.
    """
    mols = dmy.get_rndethane_mols()
    total_atoms = sum(len(mol.types) for mol in mols)

    atoms = GNR.make_atom_df(mols)
    assert len(atoms["molecule_name"].unique()) == len(mols)

    matched = 0
    for row, atom_idx in enumerate(atoms['atom_index'].values):
        owners = [mol for mol in mols
                  if atoms['molecule_name'][row] == mol.molid]
        for mol in owners:
            matched += 1
            assert np.array_equal(mol.conn[atom_idx], atoms['conn'][row])
            x, y, z = mol.xyz[atom_idx][0], mol.xyz[atom_idx][1], mol.xyz[atom_idx][2]
            assert x == atoms['x'][row]
            assert y == atoms['y'][row]
            assert z == atoms['z'][row]

    assert matched == total_atoms