コード例 #1
0
def test_enhance_atoms():
    """enhance_atoms should attach element symbol, connectivity and distances
    that agree with the source molecules for every atom row."""

    p_table = Get_periodic_table()

    # Build the inputs enhance_atoms relies on.
    mols = dmy.get_rndethane_mols(distance=True)
    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)

    BCAI.enhance_atoms(atoms, structure_dict)

    # molid -> molecule; later duplicates win, exactly like a linear scan
    # that keeps the last match.
    mols_by_id = {ml_fnd.molid: ml_fnd for ml_fnd in mols}

    for i in range(len(atoms.index)):
        molid = atoms['molecule_name'][i]
        # Fail with a clear message instead of an AttributeError on a sentinel.
        assert molid in mols_by_id, f'molecule {molid!r} not found in mols'
        mol = mols_by_id[molid]

        atid = atoms['atom_index'][i]
        assert p_table.index(atoms['typestr'][i]) == mol.types[atid]
        assert np.array_equal(atoms['conn'][i], mol.conn[atid])
        assert np.array_equal(atoms['distance'][i], mol.dist[atid])
コード例 #2
0
def test_get_dummy_features():
    """Check the targets and reference indices get_dummy_features produces
    on ethane for atomic (CCS/HCS) and coupling (1JCH/3JHH) flags."""

    ethane = dmy.get_ethane_mol()
    mols = [ethane]

    # Atomic targets: no features, shifts of the selected element, and
    # [molecule, atom] references.
    atomic_cases = (
        ('CCS', 6, [['ethane', 0], ['ethane', 1]]),
        ('HCS', 1, [['ethane', k] for k in range(2, 8)]),
    )
    for flag, element, expected_refs in atomic_cases:
        feats, targets, refs = GNR.get_dummy_features(mols, targetflag=flag)
        assert feats == []
        assert np.array_equal(targets,
                              ethane.shift[np.where(ethane.types == element)])
        assert np.array_equal(refs, expected_refs)

    # Coupling targets: each [molecule, i, j] reference must index the
    # coupling matrix value reported as its target.
    pair_cases = (
        ('1JCH', [['ethane', c, h]
                  for c, hydrogens in ((0, (2, 3, 4)), (1, (5, 6, 7)))
                  for h in hydrogens]),
        ('3JHH', [['ethane', h1, h2]
                  for h1 in (2, 3, 4)
                  for h2 in (5, 6, 7)]),
    )
    for flag, expected_refs in pair_cases:
        _, targets, refs = GNR.get_dummy_features(mols, targetflag=flag)
        for pos, ref in enumerate(refs):
            assert targets[pos] == ethane.coupling[ref[1]][ref[2]]
        assert np.array_equal(refs, expected_refs)
コード例 #3
0
def test_make_triplets():
    """make_triplets should emit exactly one row per bonded triplet
    (atom2 - atom1 - atom3, atom3 > atom2) with the matching bond angle."""

    # Build fully enhanced atoms/bonds as the BCAI pipeline would.
    mols = dmy.get_rndethane_mols(distance=True)

    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)

    bonds = GNR.make_bonds_df(mols)
    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')

    triplets = BCAI.make_triplets(bonds["molecule_name"].unique(),
                                  structure_dict)
    assert len(triplets["molecule_name"].unique()) == len(mols)

    count = 0
    for mol in mols:
        n_atoms = len(mol.types)
        for atom1 in range(n_atoms):
            for atom2 in range(n_atoms):
                if atom2 == atom1:
                    continue
                # Each unordered pair of neighbours is counted once.
                for atom3 in range(atom2 + 1, n_atoms):
                    if atom3 == atom1:
                        continue
                    # Both end atoms must be directly bonded to the centre.
                    if (mol.conn[atom1][atom2] != 1
                            or mol.conn[atom1][atom3] != 1):
                        continue

                    row = triplets.loc[(triplets.molecule_name == mol.molid)
                                       & (triplets.atom_index_0 == atom1)
                                       & (triplets.atom_index_1 == atom2)
                                       & (triplets.atom_index_2 == atom3)]

                    assert len(row.index) == 1

                    ba = mol.xyz[atom2] - mol.xyz[atom1]
                    bc = mol.xyz[atom3] - mol.xyz[atom1]

                    cosine = np.sum(ba * bc) / (np.linalg.norm(ba) *
                                                np.linalg.norm(bc))
                    angle = np.arccos(np.clip(cosine, -1.0, 1.0))

                    # The angle is computed independently here and in
                    # make_triplets; compare with a tolerance, not exactly.
                    assert np.isclose(angle, row.angle.values[0])
                    count += 1

    assert count == len(triplets.index)
コード例 #4
0
def test_get_atomic_ACSF():
    """ACSF atomic features should be produced for the random ethanes.

    Previously this only called the function and asserted nothing.
    """

    mols = dmy.get_rndethane_mols()
    atoms = GNR.make_atom_df(mols)
    struc = GNR.make_struc_dict(atoms)
    bonds = GNR.make_bonds_df(mols)

    atoms = QML.get_atomic_qml_features(atoms,
                                        bonds,
                                        struc,
                                        featureflag='ACSF')

    # Same output column the FCHL/aSLATM/CMAT tests read.
    assert 'atomic_rep' in atoms.columns
    assert len(atoms['atomic_rep'].values[0]) > 0
コード例 #5
0
    def get_features_frommols(self,
                              args,
                              params=None,
                              molcheck_run=False,
                              training=True,
                              max=200):
        """Build atom/bond/structure tables and the requested features.

        Args:
            args: mapping with 'targetflag' and 'featureflag' keys.
            params: hyper-parameter dict, stored as ``self.params`` and given
                a 'max' entry. QML feature flags read ``params['cutoff']``.
                When omitted, a fresh dict is created for this call.
            molcheck_run: if True, return right after ``self.remove_mols``.
            training: forwarded to the BCAI feature builder.
            max: maximum molecule size; raised to fit larger molecules.

        Returns:
            None normally; 0 when the feature flag is not recognised.
        """
        # Fresh dict per call: a shared ``{}`` default would be mutated below
        # (``self.params['max'] = ...``) and leak into later calls.
        if params is None:
            params = {}
        self.params = params

        target = flag_to_target(args['targetflag'])
        self.remove_mols(target)
        if molcheck_run:
            return

        for mol in self.mols:
            if len(mol.types) > max:
                max = len(mol.types) + 1
                print('WARNING, setting max atoms to ', max)

        self.params['max'] = max

        self.atoms = GNR.make_atom_df(self.mols)
        self.struc = GNR.make_struc_dict(self.atoms)
        # Four-character target flags (e.g. coupling flags) also need bonds.
        # NOTE(review): for other flags self.bonds is not set here, yet the
        # QML/BCAI branches read it — confirm it is initialised elsewhere.
        if len(args['targetflag']) == 4:
            self.bonds = GNR.make_bonds_df(self.mols)

        if args['featureflag'] in ['aSLATM', 'CMAT', 'FCHL', 'ACSF']:
            from autoenrich.ml.features import QML_features
            self.atoms = QML_features.get_atomic_qml_features(
                self.atoms,
                self.bonds,
                self.struc,
                featureflag=args['featureflag'],
                cutoff=params['cutoff'],
                max=max)

        elif args['featureflag'] == 'BCAI':
            from autoenrich.ml.features import TFM_features
            self.BCAI, self.atoms, self.bonds, self.struc, xfiles, rfiles, yfiles = TFM_features.get_BCAI_features(
                self.atoms,
                self.bonds,
                self.struc,
                targetflag=args['targetflag'],
                training=training)

        elif args['featureflag'] == 'dummy':
            # 'dummy' needs no extra features beyond the tables built above.
            return

        else:
            print('Feature flag not recognised, no feature flag: ',
                  args['featureflag'])
            return 0
コード例 #6
0
def test_make_struc_dict():
    """make_struc_dict should hold one entry per molecule whose element
    symbols, positions and connectivity mirror that molecule."""

    periodic = Get_periodic_table()

    mols = dmy.get_rndethane_mols()
    structures = GNR.make_struc_dict(GNR.make_atom_df(mols))
    assert len(structures.keys()) == len(mols)

    for molecule in mols:
        entry = structures[molecule.molid]
        assert entry['typesstr'] == [periodic[z] for z in molecule.types]
        assert np.array_equal(entry['positions'], molecule.xyz)
        assert np.array_equal(entry['conn'], molecule.conn)
コード例 #7
0
def test_get_atomic_FCHL():
    """The first atom's FCHL representation should have shape (5, 50)."""

    mols = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(mols)
    struc_dict = GNR.make_struc_dict(atom_df)
    bond_df = GNR.make_bonds_df(mols)

    featurised = QML.get_atomic_qml_features(atom_df, bond_df, struc_dict,
                                             featureflag='FCHL')

    first_rep = featurised['atomic_rep'].values[0]
    assert first_rep.shape == (5, 50)
コード例 #8
0
def test_get_atomic_aSLATM():
    """The first atom's aSLATM representation should have 20105 entries."""

    mols = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(mols)
    struc_dict = GNR.make_struc_dict(atom_df)
    bond_df = GNR.make_bonds_df(mols)

    featurised = QML.get_atomic_qml_features(atom_df, bond_df, struc_dict,
                                             featureflag='aSLATM')

    first_rep = featurised['atomic_rep'].values[0]
    assert len(first_rep) == 20105
コード例 #9
0
def test_enhance_bonds():
    """enhance_bonds should keep coupling values/lengths consistent with the
    molecules and mark exactly the '3JHH' bonds for prediction."""

    # Build enhanced atoms and raw bonds as the BCAI pipeline would.
    mols = dmy.get_rndethane_mols(distance=True)

    atoms = GNR.make_atom_df(mols)
    structure_dict = GNR.make_struc_dict(atoms)
    BCAI.enhance_structure_dict(structure_dict)
    BCAI.enhance_atoms(atoms, structure_dict)

    bonds = GNR.make_bonds_df(mols)

    BCAI.enhance_bonds(bonds, structure_dict, flag='3JHH')

    # molid -> molecule; later duplicates win, like the former linear scan.
    mols_by_id = {ml_fnd.molid: ml_fnd for ml_fnd in mols}

    # Iterating a DataFrame directly yields its *column labels*, not its rows,
    # so walk the row index explicitly to check every bond.
    for idx in bonds.index:
        molid = bonds['molecule_name'][idx]
        at1 = bonds['atom_index_0'][idx]
        at2 = bonds['atom_index_1'][idx]

        assert molid in mols_by_id, f'molecule {molid!r} not found in mols'
        mol = mols_by_id[molid]

        # The leading character of 'type' (e.g. '3JHH') is the bond count.
        assert mol.coupling_len[at1][at2] == int(bonds['type'][idx][0])
        assert mol.coupling[at1][at2] == bonds['scalar_coupling_constant'][idx]
        if bonds['labeled_type'][idx] == '3JHH':
            assert bonds['predict'][idx] == 1
        else:
            assert bonds['predict'][idx] == 0

    # Every three-bond H-H pair in the molecules must be present and flagged.
    for mol in mols:
        for atom1, type1 in enumerate(mol.types):
            for atom2, type2 in enumerate(mol.types):

                if atom1 == atom2:
                    continue

                row = bonds.loc[(bonds['molecule_name'] == mol.molid)
                                & (bonds['atom_index_0'] == atom1)
                                & (bonds['atom_index_1'] == atom2)]

                cpl = row['scalar_coupling_constant'].values

                if type1 == 1 and type2 == 1 and mol.coupling_len[atom1][
                        atom2] == 3:
                    assert row.predict.values == 1
                    assert mol.coupling[atom1][atom2] == cpl[0]
コード例 #10
0
def test_get_atomic_cmat():
    """The first atom's Coulomb-matrix representation should have
    50 * 51 / 2 entries, of which n * (n + 1) / 2 are non-zero, where n is
    the number of atoms in the first molecule."""

    mols = dmy.get_rndethane_mols()
    atom_df = GNR.make_atom_df(mols)
    struc_dict = GNR.make_struc_dict(atom_df)
    bond_df = GNR.make_bonds_df(mols)

    featurised = QML.get_atomic_qml_features(atom_df, bond_df, struc_dict,
                                             featureflag='CMAT')

    rep = featurised['atomic_rep'].values[0]
    assert len(rep) == 50 * (50 + 1) / 2

    n_atoms = len(mols[0].types)
    assert len(rep.nonzero()[0]) == n_atoms * (n_atoms + 1) / 2
コード例 #11
0
def test_get_scaling():
    """Smoke test: BCAI.get_scaling runs on enhanced bond data and returns
    a (means, stds) pair."""

    mols = dmy.get_rndethane_mols(distance=True)

    atom_df = GNR.make_atom_df(mols)
    struc_dict = GNR.make_struc_dict(atom_df)
    BCAI.enhance_structure_dict(struc_dict)
    BCAI.enhance_atoms(atom_df, struc_dict)

    bond_df = GNR.make_bonds_df(mols)
    BCAI.enhance_bonds(bond_df, struc_dict, flag='3JHH')

    # Only checks that scaling can be computed and unpacked without error.
    means, stds = BCAI.get_scaling(bond_df)
コード例 #12
0
def test_make_bonds_df():
    """make_bonds_df should cover every molecule, and each bond row's coupling
    length and constant should match its source molecule."""

    mols = dmy.get_rndethane_mols()

    bonds = GNR.make_bonds_df(mols)
    assert len(bonds["molecule_name"].unique()) == len(mols)

    # molid -> molecule; later duplicates win, like the former linear scan.
    mols_by_id = {ml_fnd.molid: ml_fnd for ml_fnd in mols}

    # Iterating a DataFrame directly yields its *column labels*, not its rows,
    # so walk the row index explicitly to check every bond.
    for idx in bonds.index:
        molid = bonds['molecule_name'][idx]

        at1 = bonds['atom_index_0'][idx]
        at2 = bonds['atom_index_1'][idx]

        assert molid in mols_by_id, f'molecule {molid!r} not found in mols'
        mol = mols_by_id[molid]

        # The leading character of 'type' (e.g. '3JHH') is the bond count.
        assert mol.coupling_len[at1][at2] == int(bonds['type'][idx][0])
        assert mol.coupling[at1][at2] == bonds['scalar_coupling_constant'][idx]
コード例 #13
0
def test_enhance_structure_dict():
    """After enhance_structure_dict, each entry should still mirror its
    molecule and also carry the molecule's distance matrix."""

    periodic = Get_periodic_table()

    mols = dmy.get_rndethane_mols(distance=True)
    structures = GNR.make_struc_dict(GNR.make_atom_df(mols))

    BCAI.enhance_structure_dict(structures)

    for molecule in mols:
        entry = structures[molecule.molid]
        symbols = [periodic[z] for z in molecule.types]
        assert entry['typesstr'] == symbols
        for key, expected in (('positions', molecule.xyz),
                              ('conn', molecule.conn),
                              ('distances', molecule.dist)):
            assert np.array_equal(entry[key], expected)
コード例 #14
0
def test_add_embedding():
    """Smoke test: BCAI.add_embedding accepts enhanced atoms, bonds and
    triplets and returns four values."""

    periodic = Get_periodic_table()  # not used below

    mols = dmy.get_rndethane_mols(distance=True)

    atom_df = GNR.make_atom_df(mols)
    struc_dict = GNR.make_struc_dict(atom_df)
    BCAI.enhance_structure_dict(struc_dict)
    BCAI.enhance_atoms(atom_df, struc_dict)

    bond_df = GNR.make_bonds_df(mols)
    BCAI.enhance_bonds(bond_df, struc_dict, flag='3JHH')

    molecule_names = bond_df["molecule_name"].unique()
    triplet_df = BCAI.make_triplets(molecule_names, struc_dict)

    embeddings, atom_df, bond_df, triplet_df = BCAI.add_embedding(
        atom_df, bond_df, triplet_df)
コード例 #15
0
def test_make_atom_df():
    """make_atom_df should produce one row per atom whose connectivity and
    x/y/z coordinates match the source molecule."""

    mols = dmy.get_rndethane_mols()
    total_atoms = sum(len(m.types) for m in mols)

    atoms = GNR.make_atom_df(mols)
    assert len(atoms["molecule_name"].unique()) == len(mols)

    counted = 0
    for row, atom_idx in enumerate(atoms['atom_index'].values):
        for m in mols:
            if atoms['molecule_name'][row] != m.molid:
                continue
            counted += 1
            assert np.array_equal(m.conn[atom_idx], atoms['conn'][row])
            for dim, axis in enumerate(('x', 'y', 'z')):
                assert m.xyz[atom_idx][dim] == atoms[axis][row]

    assert counted == total_atoms
コード例 #16
0
    def get_features_frommols(self,
                              args,
                              params=None,
                              molcheck_run=False,
                              training=True):
        """Build feature matrix ``self.x``, targets ``self.y`` and references
        ``self.r`` for the molecules in ``self.mols``.

        Args:
            args: mapping with 'featureflag' and 'targetflag' keys; an
                optional 'max_size' key sets the maximum molecule size
                (default 200, raised to fit the largest molecule).
            params: hyper-parameter dict, stored as ``self.params``. Its
                'cutoff' entry is clamped to at least 0.1, or set to 5.0 when
                absent. A caller-supplied dict is updated in place; when
                omitted, a fresh dict is created for this call.
            molcheck_run: if True, return right after ``self.remove_mols``.
            training: forwarded to the BCAI feature builder.

        Side effects:
            Sets ``self.params``, ``self.x``, ``self.y``, ``self.r`` and, for
            the 'BCAI' flag, ``self.mol_order``.
        """
        # Fresh dict per call: a shared ``{}`` default would be mutated below
        # (``params['cutoff'] = ...``) and leak into later calls.
        if params is None:
            params = {}

        featureflag = args['featureflag']
        targetflag = args['targetflag']
        try:
            max = args['max_size']
        except KeyError:
            max = 200

        # Grow the size limit to fit the largest molecule.
        # NOTE(review): ``mol`` remains bound to the last molecule after this
        # loop, and that single molecule is what the QML branches below
        # featurise — confirm whether all of self.mols should be processed.
        for mol in self.mols:
            if len(mol.types) > max:
                max = len(mol.types)
                print('WARNING, SETTING MAXIMUM MOLECULE SIZE TO, ', max)

        if 'cutoff' in params:
            if params['cutoff'] < 0.1:
                params['cutoff'] = 0.1
        else:
            params['cutoff'] = 5.0
        # The QML branches below need this name; it was previously undefined.
        cutoff = params['cutoff']

        x = []
        y = []
        r = []

        self.params = params

        target = flag_to_target(targetflag)
        self.remove_mols(target)
        if molcheck_run:
            return

        if featureflag in ['aSLATM', 'CMAT', 'FCHL', 'ACSF']:
            import qml
        elif featureflag in ['BCAI']:
            from autoenrich.ml.features import TFM_features

        _, y, r = GNR_features.get_dummy_features(self.mols, targetflag)

        if featureflag == 'aSLATM':
            # Fixed many-body types over H, C, N, O, F, in nested order:
            # [a], then for each b: [a, b] followed by every [a, b, c]
            # (155 entries). A data-driven alternative would be
            # qml.representations.get_slatm_mbtypes(nuclear_charges).
            slatm_elements = [1, 6, 7, 8, 9]
            mbtypes = []
            for a in slatm_elements:
                mbtypes.append([a])
                for b in slatm_elements:
                    mbtypes.append([a, b])
                    mbtypes.extend([a, b, c] for c in slatm_elements)
            reps = qml.representations.generate_slatm(mol.xyz,
                                                      mol.types,
                                                      mbtypes,
                                                      rcut=cutoff)
            x = np.asarray(reps)

        elif featureflag == 'CMAT':
            reps = qml.representations.generate_atomic_coulomb_matrix(
                mol.types, mol.xyz, size=max, central_cutoff=cutoff)
            x = np.asarray(reps)

        elif featureflag == 'FCHL':
            reps = qml.fchl.generate_representation(mol.xyz,
                                                    mol.types,
                                                    max,
                                                    cut_distance=cutoff)
            x = np.asarray(reps)

        elif featureflag == 'ACSF':
            # NOTE(review): nRs2, nRs3, nTs, eta2, eta3, zeta and acut are not
            # defined in this method (NameError if reached); presumably they
            # should be read from ``params`` — TODO confirm the keys.
            reps = qml.representations.generate_acsf(
                mol.types,
                mol.xyz,
                elements=[1, 6, 7, 8, 9, 14, 15, 16, 17, 35],
                nRs2=int(nRs2),
                nRs3=int(nRs3),
                nTs=int(nTs),
                eta2=eta2,
                eta3=eta3,
                zeta=zeta,
                rcut=cutoff,
                acut=acut,
                bin_min=0.0,
                gradients=False)
            x = np.asarray(reps)

        elif featureflag == 'BCAI':

            _x, _y, _r, mol_order = TFM_features.get_BCAI_features(
                self.mols, targetflag, training=training)

            # NOTE(review): y and r already hold get_dummy_features output at
            # this point; confirm that extending (not replacing) is intended.
            x.extend(_x)
            y.extend(_y)
            r.extend(_r)

        elif featureflag == 'dummy':
            # Targets and references from get_dummy_features are all that is
            # needed; x stays empty.
            pass

        else:
            print('Feature flag not recognised, no feature flag: ',
                  featureflag)

        if featureflag == 'BCAI':
            self.x = x
            self.y = y
            self.r = r
            self.mol_order = mol_order
        else:
            self.x = np.asarray(x)
            self.y = np.asarray(y)
            self.r = r

        if featureflag not in ['dummy', 'BCAI']:
            print('Reps generated, shape: ', self.x.shape)