Пример #1
0
def get_BCAI_features(atoms, bonds, struc, targetflag='CCS', training=True):
	"""Build the BCAI dataset and its saved x/y/r outputs from atom and bond tables.

	Args:
		atoms: atom table; passed to ``BCAI.enhance_atoms`` and then wrapped in
			a ``pd.DataFrame``.
		bonds: bond table; passed to ``BCAI.enhance_bonds`` and then wrapped in
			a ``pd.DataFrame``. Must expose a ``"molecule_name"`` column.
		struc: structure dictionary handed to the ``BCAI.enhance_*`` helpers
			and to ``BCAI.make_triplets``.
		targetflag: NMR parameter flag, forwarded to ``flag_to_target``.
		training: when true the dataset is written with ``save_split_dataset``,
			otherwise with ``save_dataset``.

	Returns:
		Tuple ``(Dset, atoms, bonds, struc, x, y, r)``.
	"""
	# The result is not used below; the call is kept in case it validates the flag.
	flag_to_target(targetflag)

	# The original referenced an undefined name ``structure_dict`` here; the
	# structure dictionary is the ``struc`` argument.
	BCAI.enhance_structure_dict(struc)
	BCAI.enhance_atoms(atoms, struc)
	bonds = BCAI.enhance_bonds(bonds, struc)

	triplets = BCAI.make_triplets(bonds["molecule_name"].unique(), struc)

	atoms = pd.DataFrame(atoms)
	bonds = pd.DataFrame(bonds)
	triplets = pd.DataFrame(triplets)

	atoms.sort_values(['molecule_name', 'atom_index'], inplace=True)
	bonds.sort_values(['molecule_name', 'atom_index_0', 'atom_index_1'], inplace=True)
	triplets.sort_values(
		['molecule_name', 'atom_index_0', 'atom_index_1', 'atom_index_2'],
		inplace=True)

	_embeddings, atoms, bonds, triplets = BCAI.add_embedding(atoms, bonds, triplets)

	# NOTE(review): the original called ``bonds.dropna()`` and ``atoms.dropna()``
	# here and discarded the results, so no rows were ever removed. The no-op
	# calls are dropped; if NaN filtering is really wanted, assign the result
	# back, but first confirm that unlabeled (prediction) rows would survive it.
	means, stds = BCAI.get_scaling(bonds)
	bonds = BCAI.add_scaling(bonds, means, stds)

	# NOTE(review): ``mol_order`` is only bound further down by the save call,
	# so reading it here fails with UnboundLocalError. The intended value is not
	# visible in this file - confirm what ``create_dataset`` expects.
	Dset = BCAI.create_dataset(atoms, bonds, triplets, labeled=True,
							   max_count=10**10, mol_order=mol_order)

	if training:
		x, y, r, mol_order = save_split_dataset(Dset)
	else:
		x, y, r, mol_order = save_dataset(Dset)

	return Dset, atoms, bonds, struc, x, y, r
Пример #2
0
    def get_features_frommols(self,
                              args,
                              params=None,
                              molcheck_run=False,
                              training=True,
                              max=200):
        """Build atom/bond tables for ``self.mols`` and attach the requested features.

        Args:
            args: mapping providing ``'targetflag'`` and ``'featureflag'``.
            params: feature parameters; stored on ``self.params`` and extended
                with the ``'max'`` key. ``'cutoff'`` is required for the QML
                feature flags. A new dict is created when omitted.
            molcheck_run: when true, stop right after ``self.remove_mols``.
            training: forwarded to ``TFM_features.get_BCAI_features``.
            max: initial maximum atom count; raised when a molecule in
                ``self.mols`` has more atoms.

        Returns:
            ``0`` when the feature flag is not recognised, otherwise ``None``.
        """
        # A shared ``{}`` default would be mutated by the 'max' assignment below
        # and leak between calls, so a fresh dict is created per call instead.
        if params is None:
            params = {}
        self.params = params

        target = flag_to_target(args['targetflag'])
        self.remove_mols(target)
        if molcheck_run:
            return

        # Grow the size limit so that the largest molecule fits.
        for mol in self.mols:
            if len(mol.types) > max:
                max = len(mol.types) + 1
                print('WARNING, setting max atoms to ', max)

        self.params['max'] = max

        self.atoms = GNR.make_atom_df(self.mols)
        self.struc = GNR.make_struc_dict(self.atoms)
        # NOTE(review): ``self.bonds`` is only built for four-character target
        # flags, yet both feature branches below read it - confirm it is
        # initialised elsewhere for the other flags.
        if len(args['targetflag']) == 4:
            self.bonds = GNR.make_bonds_df(self.mols)

        featureflag = args['featureflag']

        if featureflag in ['aSLATM', 'CMAT', 'FCHL', 'ACSF']:
            from autoenrich.ml.features import QML_features
            self.atoms = QML_features.get_atomic_qml_features(
                self.atoms,
                self.bonds,
                self.struc,
                featureflag=featureflag,
                cutoff=params['cutoff'],
                max=max)

        elif featureflag == 'BCAI':
            from autoenrich.ml.features import TFM_features
            # The three trailing return values are not used here.
            (self.BCAI, self.atoms, self.bonds, self.struc,
             _, _, _) = TFM_features.get_BCAI_features(
                 self.atoms,
                 self.bonds,
                 self.struc,
                 targetflag=args['targetflag'],
                 training=training)

        # The original tested ``!= 'dummy'`` here, which reported 'dummy' as
        # unrecognised and silently accepted every unknown flag.
        elif featureflag == 'dummy':
            return

        else:
            print('Feature flag not recognised, no feature flag: ',
                  featureflag)
            return 0
Пример #3
0
def get_dummy_features(mols, targetflag='CCS'):
    """Collect target values and their references without building features.

    Args:
        mols: iterable of molecule objects exposing ``types``, ``molid`` and,
            depending on the target, ``shift`` or ``coupling`` /
            ``coupling_len``.
        targetflag: NMR parameter flag, converted with ``flag_to_target``.

    Returns:
        ``(x, y, r)`` where ``x`` is always an empty list, ``y`` holds the
        selected NMR parameters and ``r`` the matching references
        (``[molid, atom]`` or ``[molid, atom1, atom2]``).
    """
    target = flag_to_target(targetflag)

    features = []
    values = []
    refs = []

    for mol in mols:
        # One-element target: chemical shift of every atom of that type.
        if len(target) == 1:
            for idx, atom_type in enumerate(mol.types):
                if atom_type == target[0]:
                    values.append(mol.shift[idx])
                    refs.append([mol.molid, idx])

        # Three-element target: (coupling length, first type, second type).
        if len(target) == 3:
            for first in range(len(mol.types)):
                for second in range(first + 1, len(mol.types)):
                    types_match = (mol.types[first] == target[1]
                                   and mol.types[second] == target[2])
                    if types_match and mol.coupling_len[first][second] == target[0]:
                        values.append(mol.coupling[first][second])
                        refs.append([mol.molid, first, second])

    return features, values, refs
Пример #4
0
def test_predictFCHLmodel():
    """Train an FCHL model per target flag and check its predictions on training data."""
    dset = get_test_dataset()

    for targetflag in ['HCS', '1JCH']:
        args = dict(featureflag='FCHL', targetflag=targetflag)
        # Built per iteration because feature generation may modify the dict.
        params = dict(cutoff=5.0)

        dset.get_features_frommols(args, params)
        model = FCHLmodel()

        # Result is not needed by the assertions below; the call is kept as is.
        flag_to_target(args['targetflag'])

        model.get_x(dset, args['targetflag'], assign_train=True)
        model.train()

        assert model.trained == True

        train_x, train_y = model.get_x(dset,
                                       args['targetflag'],
                                       assign_train=False)
        sample_x = np.asarray(train_x[:, :4])
        predicted = model.predict(sample_x)

        # Predictions on training samples must reproduce the training labels.
        assert np.allclose(predicted, train_y[:4], atol=0.1, rtol=0.0)
        assert sum(predicted) != 0.0
Пример #5
0
def _probe_file(path):
    """Open *path* for reading and close it again; raises if that fails."""
    with open(path, 'r'):
        pass


def _probe_first_match(pattern):
    """Probe the first file matching glob *pattern* (IndexError if none match)."""
    _probe_file(glob.glob(pattern)[0])


def _prompt_yes_no(question):
    """Ask *question* until answered; return 'True' (empty/y/Y) or 'False' (n/N)."""
    while True:
        reply = input(question)
        if len(reply) == 0 or reply[0] in ['y', 'Y']:
            return 'True'
        if reply[0] in ['n', 'N']:
            return 'False'


def _prompt_number(args, key, template, cast):
    """Prompt for ``args[key]``; an empty reply keeps the current value.

    *template* is formatted once with the current value. Only a reply that
    *cast* rejects triggers a new prompt; a value that cannot be formatted
    raises instead of looping forever.
    """
    prompt = template.format(args[key])
    while True:
        reply = input(prompt)
        if len(reply) == 0:
            return
        try:
            args[key] = cast(reply)
            return
        except ValueError as err:
            print(err)


def _prompt_model_and_feature(args):
    """Ask for a model and a feature until the combination is accepted."""
    combocheck = False
    while not combocheck:
        # Model ##############################################################
        while True:
            model = input(
                "What type of model do you want ? KRR, FCHL, NN, TFM : \n")
            if model in ['KRR', 'FCHL', 'NN', 'TFM']:
                args['modelflag'] = model
                break

        # Only one available option for FCHL model, must be FCHL feature
        if model == 'FCHL':
            args['featureflag'] = 'FCHL'
            combocheck = True
            continue

        # Feature ############################################################
        while True:
            feature = input(
                "What type of input features do you want ? CMAT, aSLATM, FCHL, ACSF, BCAI : \n"
            )
            if feature in ['CMAT', 'aSLATM', 'FCHL', 'ACSF', 'BCAI']:
                args['featureflag'] = feature
                break
            print('Requested feature [', feature, '] not recognised. . .')

        # Check for model and feature combination, only some actually work
        combocheck = flag_combos.check_combination(args['modelflag'],
                                                   args['featureflag'])
        if not combocheck:
            print('Invalid combination of model and feature, try again . . .')


def _select_targets(args, default):
    """Fill ``args['target_list']`` / ``args['targetflag']`` with valid target flags.

    Every flag has to be accepted by ``hdl_targetflag.flag_to_target`` (a
    return value of 0 marks a flag as invalid). In default mode there is nobody
    to ask again, so an invalid or empty list raises ``ValueError``.
    """
    while True:
        if default:
            if args['target_list'] == '':
                args['target_list'] = ['HCS', 'CCS', '1JCH']
            elif isinstance(args['target_list'], str):
                # A plain string would otherwise be iterated character-wise.
                args['target_list'] = args['target_list'].split()
        else:
            target_list = input(
                "What target parameter(s) are you interested in ? XCS or nJXY, e.g. HCS CCS 1JCH : \n"
            )
            args['target_list'] = target_list.split()
            print(args['target_list'])

        targets = args['target_list']
        all_valid = len(targets) > 0
        for target in targets:
            if hdl_targetflag.flag_to_target(target) == 0:
                print('Invalid parameter flag')
                all_valid = False

        if all_valid:
            args['targetflag'] = targets[0]
            return
        if default:
            raise ValueError(
                'Invalid target list: {0!r}'.format(args['target_list']))


def _configure_training(args, default):
    """Collect every setting needed by the 'setup_train' / 'train' commands."""
    # Training set ###########################################################
    # Repeat until the given file (or first file of the pattern) can be opened
    while True:
        training_set = args['training_set']
        single_file = training_set.split('.')[-1] in ('pkl', 'csv')
        try:
            if single_file:
                _probe_file(training_set)
            else:
                _probe_first_match(training_set)
        except Exception as e:
            # Any failure simply means: ask the user for another training set
            print(e)
            args['training_set'] = input("Training set: \n")
        else:
            break

    # Store datasets #########################################################
    if default:
        args['store_datasets'] = 'True'
    else:
        args['store_datasets'] = _prompt_yes_no(
            "Do you want to store dataset as pickle after creation ? [y]/n: \n"
        )

    # Assign default model based on feature selection
    if default:
        if args['featureflag'] == 'FCHL':
            args['modelflag'] = 'FCHL'
        # NOTE(review): 'CCS' is a target flag, not a feature flag; 'CMAT' may
        # have been intended here - confirm before changing.
        elif args['featureflag'] in ['aSLATM', 'CCS', 'ACSF']:
            args['modelflag'] = 'KRR'
        # The original wrote ``args['modelflag'] == 'NN'`` (a comparison), so
        # the fallback was never stored. It is applied only when no model was
        # given, so an explicitly requested model is still left untouched.
        elif args.get('modelflag', '') == '':
            args['modelflag'] = 'NN'
    # Alternatively get model from user
    else:
        _prompt_model_and_feature(args)

    # Target #################################################################
    _select_targets(args, default)

    # Search method ##########################################################
    if default:
        if args['searchflag'] == '':
            args['searchflag'] = 'gaussian'
    else:
        while True:
            searchmethod = input(
                "What search method should be used ? grid, gaussian, random : \n"
            )
            if searchmethod in ['grid', 'gaussian', 'random']:
                args['searchflag'] = searchmethod
                break

    # Feature optimisation ###################################################
    if default:
        args['feature_optimisation'] = 'True'
    else:
        args['feature_optimisation'] = _prompt_yes_no(
            "Do you want to include feature parameters in optimisation ? : [y]/n \n"
        )

    # Feature file ###########################################################
    # Asked only once the decision above is settled.
    if args['feature_optimisation'] == 'False':
        while True:
            # NOTE(review): the accepted path is not stored anywhere in args,
            # exactly as in the original - confirm which key should hold it.
            feature_file = input(
                "File containing pre-made features dataset object: \n")
            try:
                _probe_file(feature_file)
                break
            except Exception as e:
                print(e)

    args['param_ranges'], args[
        'param_logs'] = paramdict.construct_param_dict(
            args['modelflag'], args['featureflag'], args['targetflag'])

    # Parameters #############################################################
    # Interactive selection of per-parameter ranges is currently disabled.

    # Cross validation steps #################################################
    if default:
        args['cv_steps'] = int(args['cv_steps'])
    else:
        _prompt_number(
            args, 'cv_steps',
            "Specify number of cross validation iterations: default = {0:<10f} \n",
            int)

    if default:
        args['epochs'] = int(args['epochs'])
        return

    if args['searchflag'] == 'grid':
        # grid density #######################################################
        _prompt_number(
            args, 'grid_density',
            "Specify grid density for parameters: default = {0:<10f} \n", int)
    elif args['searchflag'] in ['random', 'gaussian']:
        # epochs #############################################################
        # NOTE(review): stored as float here but as int in default mode.
        _prompt_number(
            args, 'epochs',
            "Specify number of epochs to run: default = {0:<10f} \n", float)

    # kappa ##################################################################
    _prompt_number(args, 'kappa',
                   "Specify kappa value: default = {0:<10f} \n", float)
    # xi #####################################################################
    _prompt_number(args, 'xi', "Specify xi value: default = {0:<10f} \n",
                   float)
    # random #################################################################
    _prompt_number(
        args, 'random',
        "Specify frequency of random samples: default = {0:<10d} \n", int)


def _configure_prediction(args, default):
    """Collect every setting needed by the 'setup_predict' / 'predict' commands."""
    # Model(s) ###############################################################
    check = False
    while not check:
        print(args['models'], '')
        if args['models'] != '':
            try:
                if not isinstance(args['models'], list):
                    args['models'] = args['models'].split()
                for model in args['models']:
                    _probe_file(model)
                check = True
            except Exception as e:
                print(e)

        if not check:
            models = input("Specify models to make predictions from: \n")
            args['models'] = models.split()

    # var model(s) ###########################################################
    if not default:
        while True:
            var = input(
                "How many models are used for pre-prediction variance ? Default=0\n variance models need to be of the format <model_file_name>_n.pkl\n"
            )
            if len(var) == 0:
                args['var'] = 0
                break
            try:
                args['var'] = int(var)
                break
            except ValueError as e:
                print(e)

    # input_datasets #########################################################
    check = False
    while not check:
        try:
            if not isinstance(args['test_sets'], list):
                args['test_sets'] = args['test_sets'].split()

            for tset in args['test_sets']:
                # Patterns are checked through their first matching file
                if '*' in tset:
                    _probe_first_match(tset)
                else:
                    _probe_file(tset)

            check = True
        except Exception as e:
            print(e)

        if not check:
            testsets = input("Specify set(s) of molecules to predict\n")
            args['test_sets'] = testsets.split()


def run_wizard(args, default=False):
    """Complete the argument dictionary, interactively unless *default* is set.

    Args:
        args: argument dictionary; ``args['Command']`` selects the training
            ('setup_train', 'train') or prediction ('setup_predict',
            'predict') section. Any other command only reaches the output
            directory prompt. The dictionary is modified in place.
        default: when true no question is asked and defaults are applied
            wherever the wizard would otherwise prompt.

    Returns:
        The same ``args`` dictionary.

    Raises:
        ValueError: in default mode when the target list is empty or contains
            a flag rejected by ``hdl_targetflag.flag_to_target``.
    """
    command = args['Command']

    if command in ('setup_train', 'train'):
        _configure_training(args, default)
    elif command in ('setup_predict', 'predict'):
        _configure_prediction(args, default)

    # output directory #######################################################
    if not default:
        while True:
            output_dir = input(
                "Set output directory ? default is current directory \n")
            if len(output_dir) == 0:
                output_dir = './'
            elif output_dir[-1] != '/':
                output_dir = output_dir + '/'

            if os.path.isdir(output_dir):
                args['output_dir'] = output_dir
                break

    return args
Пример #6
0
def compare_datasets(args):
    """Match molecules across two or three datasets and compare their NMR values.

    For every target flag a CSV is written through ``print_mol_csv`` and the
    mean absolute error between the first two sets is printed.

    Args:
        args: mapping providing ``'comp_sets'`` (glob patterns, one per set),
            ``'match_criteria'``, ``'comp_targets'`` (target flags),
            ``'comp_labels'`` and optionally ``'output_path'``.

    Raises:
        AssertionError: when fewer than two sets are given, when the first two
            sets hold different numbers of molecules, when the reference lists
            differ in length, or when ``'output_path'`` is present but empty.
    """
    att_mols = []
    sets = []
    for set_list in args['comp_sets']:
        print('Getting molecules from ', set_list)
        # Resolve the pattern once instead of on every use.
        files = glob.glob(set_list)
        dset = dataset()
        label_part = get_unique_part(files)
        dset.get_mols(files, label_part=label_part)
        print(len(dset.mols), ' molecules found from ', len(files), ' files')

        sets.append(dset)

    # The messages are plain strings: ``print(...)`` evaluates to None and
    # would leave the AssertionError without any text.
    assert len(sets) > 1, 'Only one set found. . .'
    assert len(sets[0].mols) == len(
        sets[1].mols), 'Different numbers of molecules in sets'

    # Group molecules that correspond to each other across the sets.
    found = []
    for m1, mol1 in enumerate(sets[0].mols):
        if m1 in found:
            continue
        for m2, mol2 in enumerate(sets[1].mols):

            if args['match_criteria'] == 'id':
                if mol1.molid == mol2.molid:
                    att_mols.append([mol1, mol2])
                else:
                    continue

            if not mol_isequal(mol1, mol2):
                continue

            if [mol1, mol2] in att_mols:
                continue

            if len(sets) > 2:
                for m3, mol3 in enumerate(sets[2].mols):
                    if not mol_isequal(mol1, mol3):
                        continue

                    if [mol1, mol2, mol3] in att_mols:
                        continue

                    att_mols.append([mol1, mol2, mol3])

            else:
                found.append(m1)
                att_mols.append([mol1, mol2])

    print(len(att_mols), ' molecules matched, out of ', len(sets[0].mols))

    for targetflag in args['comp_targets']:
        print(targetflag)
        # The result is not used below; the call is kept in case it validates
        # the flag.
        flag_to_target(targetflag)

        # 'dummy' features are requested so each set exposes its ``y`` values
        # and ``r`` references, which are read below.
        for dset in sets:
            dset.get_features_frommols({
                'featureflag': 'dummy',
                'targetflag': targetflag,
                'max_size': 0
            })

        values = []
        refs = []
        typerefs = []

        assert len(sets[0].r) == len(sets[1].r)
        if len(sets) > 2:
            assert len(sets[2].r) == len(sets[1].r)

        for group in att_mols:
            for i in range(len(sets[0].r)):
                ref1 = sets[0].r[i]
                if ref1[0] != group[0].molid:
                    continue
                val1 = sets[0].y[i]
                typeref1 = [group[0].types[row] for row in ref1[1:]]

                # Each inner loop is bounded by the list it actually indexes.
                for j in range(len(sets[1].r)):
                    ref2 = sets[1].r[j]
                    if ref2[0] != group[1].molid:
                        continue
                    val2 = sets[1].y[j]
                    typeref2 = [group[1].types[row] for row in ref2[1:]]

                    # Same atom indices and same atom types are required.
                    bad = any(ref1[xx] != ref2[xx]
                              for xx in range(1, len(ref1)))
                    if typeref1 != typeref2:
                        bad = True
                    if bad:
                        continue

                    if len(sets) > 2:
                        for k in range(len(sets[2].r)):
                            ref3 = sets[2].r[k]
                            if ref3[0] != group[2].molid:
                                continue
                            val3 = sets[2].y[k]
                            typeref3 = [
                                group[2].types[row] for row in ref3[1:]
                            ]

                            # Only the atom indices are compared for the
                            # third set, as in the original.
                            if any(ref1[xx] != ref3[xx]
                                   for xx in range(1, len(ref1))):
                                continue

                            refs.append([ref1, ref2, ref3])
                            values.append([val1, val2, val3])
                            typerefs.append([typeref1, typeref2, typeref3])
                    else:
                        refs.append([ref1, ref2])
                        values.append([val1, val2])
                        typerefs.append([typeref1, typeref2])

        if 'output_path' in args:
            assert len(args['output_path']) != 0
            outname = args['output_path'] + '/Comparison_' + str(
                targetflag) + '.csv'
        else:
            outname = 'Comparison_' + str(targetflag) + '.csv'

        print_mol_csv(outname, refs, typerefs, values, args['comp_labels'])

        x = [row[0] for row in values]
        y = [row[1] for row in values]

        # Mean absolute error between the first two sets only.
        MAE = np.mean(np.absolute(np.asarray(x) - np.asarray(y)))
        MAEstring = '{0:<6.3f}'.format(MAE)
        print('MAE between ', args['comp_labels'][0], args['comp_labels'][1],
              ' = ', MAEstring, '   no. of envs. ', len(x))
Пример #7
0
    def get_features_frommols(self,
                              args,
                              params=None,
                              molcheck_run=False,
                              training=True):
        """Generate features, targets and references for ``self.mols``.

        Results are stored on ``self.x``, ``self.y`` and ``self.r`` (plus
        ``self.mol_order`` for BCAI features).

        Args:
            args: mapping providing ``'featureflag'`` and ``'targetflag'`` and
                optionally ``'max_size'`` (defaults to 200).
            params: feature parameters. ``'cutoff'`` defaults to 5.0 and is
                raised to at least 0.1; the dict is updated in place and
                stored on ``self.params``. A new dict is created when omitted.
            molcheck_run: when true, stop right after ``self.remove_mols``.
            training: forwarded to ``TFM_features.get_BCAI_features``.
        """
        # A shared ``{}`` default would keep the cutoff written below and leak
        # it into later calls, so a fresh dict is created per call instead.
        if params is None:
            params = {}

        featureflag = args['featureflag']
        targetflag = args['targetflag']
        try:
            max_size = args['max_size']
        except KeyError:
            max_size = 200

        # Grow the size limit so that the largest molecule fits.
        for mol in self.mols:
            if len(mol.types) > max_size:
                max_size = len(mol.types)
                print('WARNING, SETTING MAXIMUM MOLECULE SIZE TO, ', max_size)

        if 'cutoff' in params:
            if params['cutoff'] < 0.1:
                params['cutoff'] = 0.1
        else:
            params['cutoff'] = 5.0
        # The representation calls below used an undefined name ``cutoff``.
        cutoff = params['cutoff']

        x = []
        y = []
        r = []

        self.params = params

        target = flag_to_target(targetflag)
        self.remove_mols(target)
        if molcheck_run:
            return

        if featureflag in ['aSLATM', 'CMAT', 'FCHL', 'ACSF']:
            import qml
        elif featureflag in ['BCAI']:
            from autoenrich.ml.features import TFM_features

        _, y, r = GNR_features.get_dummy_features(self.mols, targetflag)

        # NOTE(review): the qml branches below use ``mol``, i.e. whatever
        # molecule the size loop above visited last (undefined when
        # ``self.mols`` is empty), so only one molecule is represented.
        # Confirm whether a loop over all molecules was intended.
        if featureflag == 'aSLATM':
            # Every 1-, 2- and 3-body combination of H, C, N, O, F in the same
            # order as the hand-written table this replaces:
            # [a], [a, b], [a, b, c...] for each a and b.
            elements = (1, 6, 7, 8, 9)
            mbtypes = []
            for first in elements:
                mbtypes.append([first])
                for second in elements:
                    mbtypes.append([first, second])
                    mbtypes.extend(
                        [first, second, third] for third in elements)

            reps = qml.representations.generate_slatm(mol.xyz,
                                                      mol.types,
                                                      mbtypes,
                                                      rcut=cutoff)
            x = np.asarray(reps)

        elif featureflag == 'CMAT':
            reps = qml.representations.generate_atomic_coulomb_matrix(
                mol.types, mol.xyz, size=max_size, central_cutoff=cutoff)
            x = np.asarray(reps)

        elif featureflag == 'FCHL':
            reps = qml.fchl.generate_representation(mol.xyz,
                                                    mol.types,
                                                    max_size,
                                                    cut_distance=cutoff)
            x = np.asarray(reps)

        elif featureflag == 'ACSF':
            # NOTE(review): nRs2, nRs3, nTs, eta2, eta3, zeta and acut are not
            # defined in this method - presumably meant to come from
            # ``params``; confirm the key names before wiring them up.
            reps = qml.representations.generate_acsf(
                mol.types,
                mol.xyz,
                elements=[1, 6, 7, 8, 9, 14, 15, 16, 17, 35],
                nRs2=int(nRs2),
                nRs3=int(nRs3),
                nTs=int(nTs),
                eta2=eta2,
                eta3=eta3,
                zeta=zeta,
                rcut=cutoff,
                acut=acut,
                bin_min=0.0,
                gradients=False)
            x = np.asarray(reps)

        elif featureflag == 'BCAI':

            _x, _y, _r, mol_order = TFM_features.get_BCAI_features(
                self.mols, targetflag, training=training)

            x.extend(_x)
            y.extend(_y)
            r.extend(_r)

        else:
            print('Feature flag not recognised, no feature flag: ',
                  featureflag)

        if featureflag == 'BCAI':
            # BCAI outputs are kept as returned, without array conversion.
            self.x = x
            self.y = y
            self.r = r
            self.mol_order = mol_order
        else:
            self.x = np.asarray(x)
            self.y = np.asarray(y)
            self.r = r

        if featureflag not in ['dummy', 'BCAI']:
            print('Reps generated, shape: ', self.x.shape)