Ejemplo n.º 1
0
def test_get_mols():
    """get_mols should load every matching SDF file as an ethane molecule."""
    sdf_paths = glob.glob('tests/test_store/dataset/eth_rnd_mol_*.nmredata.sdf')
    loaded = dataset()
    loaded.get_mols(sdf_paths)

    # The dataset should remember the file list and stay in in-memory mode.
    assert loaded.files == sdf_paths
    assert loaded.big_data == False
    assert len(loaded.mols) == 5
    assert all(mol_is_ethane(molecule) for molecule in loaded.mols)
Ejemplo n.º 2
0
def test_dataset():
    """A freshly constructed dataset starts with empty attributes."""
    empty_set = dataset()

    # Every list-valued attribute should begin empty.
    for list_attr in ('mols', 'x', 'y', 'r', 'mol_order', 'files'):
        assert getattr(empty_set, list_attr) == []
    assert empty_set.big_data == False
    assert empty_set.params == {}
Ejemplo n.º 3
0
def get_test_dataset(size=5):
    """Return a dataset populated with `size` random ethane molecules.

    Each molecule is given a unique molid of the form 'random<i>'.
    """
    random_mols = []
    for index in range(size):
        ethane = dummymol.get_random_ethane()
        ethane.molid = 'random{0}'.format(index)
        random_mols.append(ethane)

    test_set = dataset()
    test_set.mols = random_mols

    return test_set
Ejemplo n.º 4
0
def get_dummy_dataset(ml_size=10, at_size=10, target=(1,)):
    """Build a dummy dataset of random molecules with target values extracted.

    Args:
        ml_size: number of random molecules to generate.
        at_size: number of atoms per generated molecule.
        target: target spec — length 1 means an atomic target (shift for
            that atom type); length 3 means a pair target
            [coupling_length, type_i, type_j]. An immutable tuple default
            replaces the original mutable-list default.

    Returns:
        A dataset with mols, files, and flattened x/y/r arrays populated.
    """
    dset = dataset()

    for i in range(ml_size):
        dset.mols.append(dummymol.get_random_mol(size=at_size))
        # BUG FIX: list.append takes exactly one argument; the original
        # passed two ('file_', str(i)), which raises TypeError.
        dset.files.append('file_' + str(i))

    x = []
    y = []
    r = []

    # BUG FIX: the original looped over the undefined name `mols`
    # (NameError); iterate over the molecules just added to the dataset.
    for mol in dset.mols:
        if len(target) == 1:
            # Atomic target: collect shifts for atoms of the requested type.
            for i in range(len(mol.types)):
                if mol.types[i] == target[0]:
                    y.append(mol.shift[i])
                    r.append([mol.molid, i])

        if len(target) == 3:
            # Pair target: collect couplings between distinct atom pairs of
            # the requested types with the requested coupling length.
            for i in range(len(mol.types)):
                for j in range(len(mol.types)):
                    if i == j:
                        continue

                    if not (mol.types[i] == target[1]
                            and mol.types[j] == target[2]):
                        continue

                    if mol.coupling_len[i][j] != target[0]:
                        continue

                    y.append(mol.coupling[i][j])
                    r.append([mol.molid, i, j])

    dset.x = x
    dset.y = y
    dset.r = r

    return dset
Ejemplo n.º 5
0
def predict(args):
	"""Predict NMR parameters for every test set using the stored ML models.

	For each glob pattern in args['test_sets'] the molecules are loaded,
	features are rebuilt per model in args['models'], predictions (plus an
	optional variance estimate from args['var'] extra models) are assigned
	back onto the molecules, and each molecule is written out as an
	nmredata SDF file into args['output_dir'].
	"""
	from autoenrich.molecule.dataset import dataset
	from autoenrich.file_creation.structure_formats.nmredata import nmrmol_to_nmredata

	for files_set in args['test_sets']:
		# Directory portion of the glob pattern.
		# NOTE(review): `path` is built here but never used afterwards.
		parts = files_set.split('/')
		path = ''
		for part in parts[:-1]:
			path = path + part + '/'

		files = glob.glob(files_set)
		#if len(files) == 0:
		#	print ('No file(s) found matching ', args['training_set'])
		#	sys.exit(0)
		dset = dataset()

		# label_part identifies the molecule-id token within the filenames.
		label_part = get_unique_part(files)
		dset.get_mols(files, type='nmredata', label_part=label_part)
		if len(dset.mols) == 0:
			print('No molecules loaded. . .')
			sys.exit(0)

		for m, model_file in enumerate(args['models']):

			print('Predicting from model: ', model_file)

			# NOTE(review): pickle.load on an arbitrary file is unsafe if
			# the model file can come from an untrusted source.
			model = pickle.load(open(model_file, 'rb'))

			print(model.args["targetflag"])
			# Rebuild features to match this model's feature definition.
			dset.get_features_frommols(model.args, params=model.params, training=False)
			# NOTE(review): print() returns None, so the assert's message is
			# None (the text is printed as a side effect instead).
			assert len(dset.x) > 0, print('No features made. . . ')

			if args['store_datasets']:
				pickle.dump(dset, open('OPT_testing_set.pkl', 'wb'))

			y_test, y_pred = model.predict(dset.x[0])

			# Optional variance estimate from additional resampled models.
			v_preds = []
			for i in range(args['var']):
				var_model_file = model_file.split('.pkl')[0] + '_' + str(i+1) + '.pkl'

				try:
					var_model = pickle.load(open(var_model_file, 'rb'))
				except Exception as e:
					# Best-effort: skip variance models that fail to load.
					print(e)
					continue

				# Variance models must be compatible with the main model.
				assert model.args['featureflag'] == var_model.args['featureflag']
				assert model.args['targetflag'] == var_model.args['targetflag']
				assert model.args['max_size'] == var_model.args['max_size']
				assert model.params == var_model.params, print(model.params, var_model.params)

				print('\tPredicting from ', var_model_file)
				# NOTE(review): passes dset.x here but dset.x[0] above —
				# confirm which shape model.predict actually expects.
				tmp_preds = var_model.predict(dset.x)
				v_preds.append(tmp_preds)

			if args['var'] > 0:
				var = np.var(np.asarray(v_preds), axis=0)
			else:
				var = np.zeros(len(y_pred), dtype=np.float64)

			# zero=True clears previous predictions for the first model only.
			if m == 0:
				dset.assign_from_ml(y_pred, var, zero=True)
			else:
				dset.assign_from_ml(y_pred, var, zero=False)

		# Write the annotated molecules back out as nmredata SDF files.
		for mol in dset.mols:
			outname = args['output_dir'] + 'IMP_' + mol.molid + '.nmredata.sdf'
			nmrmol_to_nmredata(mol, outname)

	print('Done')
Ejemplo n.º 6
0
def compare_datasets(args):
    """Match molecules across comparison sets and report per-target MAE.

    A dataset is loaded for every glob pattern in args['comp_sets'];
    molecules are then paired (or tripled) across sets either by molid or
    by structural equality, matching atom environments are collected for
    each flag in args['comp_targets'], written to a CSV, and the MAE
    between the first two sets is printed.
    """
    att_mols = []
    sets = []
    for set_list in args['comp_sets']:
        print('Getting molecules from ', set_list)
        # NOTE(review): `set` shadows the builtin name; left unchanged here.
        set = dataset()
        label_part = get_unique_part(glob.glob(set_list))
        set.get_mols(glob.glob(set_list), label_part=label_part)
        print(len(set.mols), ' molecules found from ',
              len(glob.glob(set_list)), ' files')

        sets.append(set)

    # NOTE(review): print() returns None, so these assert messages are None;
    # the text is printed as a side effect instead.
    assert len(sets) > 1, print('Only one set found. . .')
    assert len(sets[0].mols) == len(
        sets[1].mols), print('Different numbers of molecules in sets')

    # Pair up molecules between set 0 and set 1 (and set 2 when present).
    found = []
    for m1, mol1 in enumerate(sets[0].mols):
        if m1 in found:
            continue
        for m2, mol2 in enumerate(sets[1].mols):

            if args['match_criteria'] == 'id':
                # Match purely by molecule id when requested.
                if mol1.molid == mol2.molid:
                    att_mols.append([mol1, mol2])
                else:
                    continue

            if not mol_isequal(mol1, mol2):
                continue

            if [mol1, mol2] in att_mols:
                continue

            if len(sets) > 2:
                # Three-way comparison: also require a match in set 2.
                for m3, mol3 in enumerate(sets[2].mols):
                    if not mol_isequal(mol1, mol3):
                        continue

                    if [mol1, mol2, mol3] in att_mols:
                        continue

                    att_mols.append([mol1, mol2, mol3])

            else:
                found.append(m1)
                att_mols.append([mol1, mol2])

    print(len(att_mols), ' molecules matched, out of ', len(sets[0].mols))

    for targetflag in args['comp_targets']:
        print(targetflag)
        target = flag_to_target(targetflag)

        # Populate y/r arrays for this target on every set ('dummy' features
        # means no x features are actually computed).
        for set in sets:
            set.get_features_frommols({
                'featureflag': 'dummy',
                'targetflag': targetflag,
                'max_size': 0
            })

        values = []
        refs = []
        typerefs = []

        assert len(sets[0].r) == len(sets[1].r)
        if len(sets) > 2:
            assert len(sets[2].r) == len(sets[1].r)

        # Collect matched environment values across the paired molecules.
        for group in att_mols:
            for i in range(len(sets[0].r)):
                ref1 = sets[0].r[i]
                if ref1[0] != group[0].molid:
                    continue
                val1 = sets[0].y[i]
                # Atom types for the referenced atom indices (skip molid).
                typeref1 = [group[0].types[row] for row in ref1[1:]]

                for j in range(len(sets[0].r)):

                    ref2 = sets[1].r[j]

                    if ref2[0] != group[1].molid:
                        continue
                    val2 = sets[1].y[j]
                    typeref2 = [group[1].types[row] for row in ref2[1:]]

                    # Environments must agree on both atom indices and types.
                    bad = False
                    for xx in range(1, len(ref1)):
                        if ref1[xx] != ref2[xx]:
                            bad = True
                    if typeref1 != typeref2:
                        #print(typeref1, typeref2)
                        bad = True
                    if bad:
                        continue

                    if len(sets) > 2:
                        for k in range(len(sets[0].r)):
                            ref3 = sets[2].r[k]
                            if ref3[0] != group[2].molid:
                                continue
                            val3 = sets[2].y[k]
                            typeref3 = [
                                group[2].types[row] for row in ref3[1:]
                            ]

                            # NOTE(review): unlike the two-set case, typeref3
                            # is never compared against typeref1 — confirm
                            # whether that check was intentionally omitted.
                            bad = False
                            for xx in range(1, len(ref1)):
                                if ref1[xx] != ref3[xx]:
                                    bad = True
                            if bad:
                                continue

                            refs.append([ref1, ref2, ref3])
                            values.append([val1, val2, val3])
                            typerefs.append([typeref1, typeref2, typeref3])
                    else:
                        refs.append([ref1, ref2])
                        values.append([val1, val2])
                        typerefs.append([typeref1, typeref2])

        if 'output_path' in args:
            assert len(args['output_path']) != 0
            outname = args['output_path'] + '/Comparison_' + str(
                targetflag) + '.csv'
        else:
            outname = 'Comparison_' + str(targetflag) + '.csv'

        print_mol_csv(outname, refs, typerefs, values, args['comp_labels'])

        # MAE is reported between the first two comparison sets only.
        x = [row[0] for row in values]
        y = [row[1] for row in values]

        MAE = np.mean(np.absolute(np.asarray(x) - np.asarray(y)))
        MAEstring = '{0:<6.3f}'.format(MAE)
        print('MAE between ', args['comp_labels'][0], args['comp_labels'][1],
              ' = ', MAEstring, '   no. of envs. ', len(x))
Ejemplo n.º 7
0
def trainmodel(args):
	"""Train an ML model via hyper-parameter search (HPS) on a training set.

	The training set may come from a pickle, a csv, or a glob of nmredata
	SDF files; the BCAI feature pipeline has a dedicated loading branch.
	"""
	from autoenrich.ml.HPS import HPS
	from autoenrich.molecule.dataset import dataset
	from autoenrich.molecule.nmrmol import nmrmol

	# Record the list of searchable hyper-parameter names for HPS.
	params = []
	for param in args['param_ranges'].keys():
		params.append(param)
	args['param_list'] = params

	# NOTE(review): hard-coded to 'True', which makes the first half of the
	# BCAI_load check below always False — confirm this is intentional.
	args['BCAI_load'] = 'True'

	if args['featureflag'] == 'BCAI':

		if args['training_set'].split('.')[-1] == 'pkl':
			dset = pickle.load(open(args['training_set'], 'rb'))
		else:
			files = glob.glob(args['training_set'])
			dset = dataset()
			dset.get_mols(files, type='nmredata')

		print('Number of molecules in training set: ', len(dset.mols))

		# Given the hard-coded flag above, features are only (re)built when
		# the loaded dataset has none.
		if args['BCAI_load'] != 'True' or len(dset.x) == 0:
			dset.get_features_frommols(args, params={}, training=True)

		# Cache the featurised dataset for later runs.
		with open('training_data/BCAI_dataset.pkl', "wb") as f:
			pickle.dump(dset, f)

		dset, score = HPS(dset, args)

		print('Final Score: ', score)

	else:
		if args['feature_optimisation'] == 'True':
			if args['training_set'].split('.')[-1] == 'pkl':
				dset = pickle.load(open(args['training_set'], 'rb'))
			elif args['training_set'].split('.')[-1] == 'csv':
				dset = load_dataset_from_csv(args['training_set'])
			else:
				# NOTE(review): hard-coded to 'false', so the branch below
				# that loads 'training_data/empty_dataset.pkl' is dead code.
				args['load_dataset'] = 'false'
				if args['load_dataset'] == 'true':
					dset = pickle.load(open('training_data/empty_dataset.pkl', 'rb'))
				else:
					files = glob.glob(args['training_set'])
					dset = dataset()
					dset.get_mols(files, type='nmredata')
					assert len(dset.mols) > 0


			dset, score = HPS(dset, args)

			if args['store_datasets'] == 'True':
				pickle.dump(dset, open('OPT_training_set.pkl', 'wb'))


		else:
			dset = pickle.load(open(args['feature_file'], "rb"))
			assert len(dset.x) > 0
			assert len(dset.y) > 0

		# NOTE(review): reached after BOTH branches above, so when
		# feature_optimisation == 'True' HPS has already run once — confirm
		# the second call is intended.
		HPS(dset, args)