# Assumed context for these excerpts (not shown in the original snippets):
# edits_to_vectors, fid_data, fid_labels, DATA_FPATH, BASELINE_MODEL, and
# HYBRID_MODEL are module-level names defined elsewhere in ochem_predict_nn.
import os
import sys
import pickle
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem


def get_candidates(n=10):
	'''
	Pull n example reactions, their candidates, and the true answer
	'''

	from ochem_predict_nn.utils.database import collection_candidates
	examples = collection_candidates()

	# Define generator
	class Generator():
		def __init__(self):
			self.done_ids = set()
			self.done_smiles = set()

		def get_sequential(self):
			'''Yield solved example documents sequentially, skipping duplicate IDs and reactant SMILES'''
			for doc in examples.find({'found': True}, no_cursor_timeout = True):
				try:
					if not doc: continue 
					if doc['_id'] in self.done_ids: continue
					if doc['reactant_smiles'] in self.done_smiles: 
						print('New ID {}, but old reactant SMILES {}'.format(doc['_id'], doc['reactant_smiles']))
						continue
					self.done_ids.add(doc['_id'])
					self.done_smiles.add(doc['reactant_smiles'])
					yield doc
				except KeyboardInterrupt:
					print('Terminated early')
					sys.exit(1)

	gen = Generator()
	generator = enumerate(gen.get_sequential())

	# Initialize (this is not the best way to do this...)
	reaction_candidate_edits = []
	reaction_candidate_smiles = []
	reaction_true_onehot = []
	reaction_true = []
	counter = 0
	for i, reaction in generator:

		candidate_smiles = [a for (a, b) in reaction['edit_candidates']]
		candidate_edits = [b for (a, b) in reaction['edit_candidates']]
		reactant_smiles = reaction['reactant_smiles']
		product_smiles_true = reaction['product_smiles_true']

		reactants_check = Chem.MolFromSmiles(str(reactant_smiles))
		if not reactants_check:
			print('######### Could not parse reactants - that is weird...')
			print(reactant_smiles)
			continue

		bools = [product_smiles_true == x for x in candidate_smiles]
		print('rxn. {} : {} true entries out of {}'.format(i, sum(bools), len(bools)))
		if sum(bools) > 1:
			print('More than one true match? Will take the first one')
		if sum(bools) == 0:
			print('##### True product not found / filtered out #####')
			continue

		# Sort together, then move the (first) true candidate to the front
		zipsort = sorted(zip(bools, candidate_smiles, candidate_edits))
		zipsort = [(y, z, x) for (y, z, x) in zipsort if y][:1] + \
				  [(y, z, x) for (y, z, x) in zipsort if not y]

		if sum([y for (y, z, x) in zipsort]) != 1:
			print('New sum true: {}'.format(sum([y for (y, z, x) in zipsort])))
			print('## Wrong number of true results?')
			input('Pausing...')

		reaction_candidate_edits_compact = [
				';'.join([
					','.join(x[0]),
					','.join(x[1]),
					','.join(['%s-%s-%s' % tuple(blost) for blost in x[2]]),
					','.join(['%s-%s-%s' % tuple(bgain) for bgain in x[3]]),
				])
			for (y, z, x) in zipsort]
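
		# Illustrative example of the compact encoding above (hypothetical
		# atom map numbers): '1,2;3;4-5-1.0;' means atoms 1 and 2 each lose
		# an H, atom 3 gains an H, the order-1.0 bond between atoms 4 and 5
		# is lost, and no bond is gained.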

		# Morgan fingerprints (length 1024) of the candidate products
		prod_FPs = np.zeros((len(zipsort), 1024), dtype=bool)
		for c, candidate in enumerate([z for (y, z, x) in zipsort]):
			try:
				prod = Chem.MolFromSmiles(str(candidate))
				prod_FPs[c, :] = np.array(AllChem.GetMorganFingerprintAsBitVect(prod, 2, nBits=1024), dtype=bool)
			except Exception as e:
				print(e)
				continue

		pickle.dump((
			reaction_candidate_edits_compact,
			edits_to_vectors([], reactants_check, return_atom_desc_dict = True, ORIGINAL_VERSION = True),
			prod_FPs,
			[y for (y, z, x) in zipsort],
		), fid_data, pickle.HIGHEST_PROTOCOL)

		pickle.dump((
			str(reaction['_id']),
			str(reactant_smiles) + '>>' + str(product_smiles_true) + '[{}]'.format(len(zipsort)),
			[z for (y, z, x) in zipsort],
			reaction_candidate_edits_compact,
		), fid_labels, pickle.HIGHEST_PROTOCOL)
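
		# Each call to pickle.dump appends one record per stream; the records
		# can be read back with repeated pickle.load(fid) calls, as
		# data_generator below does for the data file.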

		counter += 1
		if counter == n:
			break
	# NOTE: these lists are never populated; each example is instead streamed
	# to the pickle files above as it is processed
	return reaction_candidate_edits, reaction_true_onehot, reaction_candidate_smiles, reaction_true
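
# Hypothetical usage sketch: fid_data and fid_labels are assumed to be opened
# at module level before calling (names and paths are illustrative only):
# fid_data = open('candidate_data.pickle', 'wb')
# fid_labels = open('candidate_labels.pickle', 'wb')
# get_candidates(n=100)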

# Truncated excerpt of the script's command-line entry point; the argparse
# setup (and some imports, e.g. the transformer module) above this point was
# cut off in the original snippet. The argument names below are inferred from
# the args.tag and args.mincount usages further down.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--tag', type=str, help='Tag of the model output folder')
    parser.add_argument('--mincount', type=int, default=50,
                        help='Mincount of templates, default 50')
    parser.add_argument('-v', type=int, default=1, help='Verbose? default 1')
    args = parser.parse_args()

    v = bool(int(args.v))
    FROOT = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'main',
        'output', str(args.tag))

    MODEL_FPATH = os.path.join(FROOT, 'model.json')
    WEIGHTS_FPATH = os.path.join(FROOT, 'weights.h5')
    ARGS_FPATH = os.path.join(FROOT, 'args.json')

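    # Probe the atom/bond feature-vector lengths with a dummy atom-mapped
    # molecule so the input dimensions F_atom and F_bond are known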
    mol = Chem.MolFromSmiles('[C:1][C:2]')
    (a, _, b, _) = edits_to_vectors((['1'], [], [('1', '2', 1.0)], []),
                                    mol,
                                    ORIGINAL_VERSION=True)
    F_atom = len(a[0])
    F_bond = len(b[0])

    # Silence warnings
    from rdkit import RDLogger
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)  # RDLogger.CRITICAL == 4

    # Load transformer
    from ochem_predict_nn.utils.database import collection_templates
    templates = collection_templates()
    Transformer = transformer.Transformer()
    # The original excerpt was cut off mid-call here; any further keyword
    # arguments to Transformer.load are missing.
    Transformer.load(templates, mincount=int(args.mincount))

def preprocess_candidate_edits(reactants, candidate_list):
    candidate_smiles = [a for (a, b) in candidate_list]
    candidate_edits = [b for (a, b) in candidate_list]

    print('Generated {} unique edit sets'.format(len(candidate_list)))
    padUpTo = len(candidate_list)

    # Maximum number of each edit type (H lost, H gained, bonds lost, bonds
    # gained) a candidate may contribute; extra edits beyond these are ignored
    N_e1 = 20
    N_e2 = 20
    N_e3 = 20
    N_e4 = 20

    # Largest edit counts actually seen, used to trim the padded arrays below
    N_e1_trim = 1
    N_e2_trim = 1
    N_e3_trim = 1
    N_e4_trim = 1

    # Initialize
    x_h_lost = np.zeros((1, padUpTo, N_e1, F_atom))
    x_h_gain = np.zeros((1, padUpTo, N_e2, F_atom))
    x_bond_lost = np.zeros((1, padUpTo, N_e3, F_bond))
    x_bond_gain = np.zeros((1, padUpTo, N_e4, F_bond))
    x = np.zeros((1, padUpTo, 1024))

    # Get reactant descriptors
    atom_desc_dict = edits_to_vectors([],
                                      reactants,
                                      return_atom_desc_dict=True,
                                      ORIGINAL_VERSION=True)

    # Populate arrays
    for (c, edits) in enumerate(candidate_edits):
        if c == padUpTo: break
        edit_h_lost_vec, edit_h_gain_vec, edit_bond_lost_vec, edit_bond_gain_vec = \
            edits_to_vectors(edits, reactants,
                             atom_desc_dict=atom_desc_dict,
                             ORIGINAL_VERSION=True)

        N_e1_trim = max(N_e1_trim, len(edit_h_lost_vec))
        N_e2_trim = max(N_e2_trim, len(edit_h_gain_vec))
        N_e3_trim = max(N_e3_trim, len(edit_bond_lost_vec))
        N_e4_trim = max(N_e4_trim, len(edit_bond_gain_vec))

        for (e, edit_h_lost) in enumerate(edit_h_lost_vec):
            if e >= N_e1: continue
            x_h_lost[0, c, e, :] = edit_h_lost
        for (e, edit_h_gain) in enumerate(edit_h_gain_vec):
            if e >= N_e2: continue
            x_h_gain[0, c, e, :] = edit_h_gain
        for (e, edit_bond_lost) in enumerate(edit_bond_lost_vec):
            if e >= N_e3: continue
            x_bond_lost[0, c, e, :] = edit_bond_lost
        for (e, edit_bond_gain) in enumerate(edit_bond_gain_vec):
            if e >= N_e4: continue
            x_bond_gain[0, c, e, :] = edit_bond_gain

        if BASELINE_MODEL or HYBRID_MODEL:
            prod = Chem.MolFromSmiles(str(candidate_smiles[c]))
            if prod is not None:
                x[0, c, :] = np.array(AllChem.GetMorganFingerprintAsBitVect(
                    prod, 2, nBits=1024),
                                      dtype=bool)

    # Trim down
    x_h_lost = x_h_lost[:, :, :N_e1_trim, :]
    x_h_gain = x_h_gain[:, :, :N_e2_trim, :]
    x_bond_lost = x_bond_lost[:, :, :N_e3_trim, :]
    x_bond_gain = x_bond_gain[:, :, :N_e4_trim, :]

    # Get rid of NaNs
    x_h_lost[np.isnan(x_h_lost)] = 0.0
    x_h_gain[np.isnan(x_h_gain)] = 0.0
    x_bond_lost[np.isnan(x_bond_lost)] = 0.0
    x_bond_gain[np.isnan(x_bond_gain)] = 0.0
    x_h_lost[np.isinf(x_h_lost)] = 0.0
    x_h_gain[np.isinf(x_h_gain)] = 0.0
    x_bond_lost[np.isinf(x_bond_lost)] = 0.0
    x_bond_gain[np.isinf(x_bond_gain)] = 0.0

    if BASELINE_MODEL:
        return [x]
    elif HYBRID_MODEL:
        return [x_h_lost, x_h_gain, x_bond_lost, x_bond_gain, x]

    return [x_h_lost, x_h_gain, x_bond_lost, x_bond_gain]
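
# Hypothetical usage sketch: score one set of reactants' candidates with a
# loaded Keras model (model and candidate_list are assumptions, not shown in
# these excerpts):
# reactants = Chem.MolFromSmiles('CCO.CC(=O)Cl')
# probs = model.predict(preprocess_candidate_edits(reactants, candidate_list))
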
def data_generator(start_at,
                   end_at,
                   batch_size,
                   max_N_c=None,
                   shuffle=False,
                   allowable_batchNums=set()):
    '''Generates batches of data from the pickle file, since the full dataset
    cannot fit in memory.

    The starting and ending indices are specified explicitly so the same
    function can be used for validation data as well.

    Input tensors are generated on the fly, so there is less I/O.

    max_N_c is the maximum number of candidates to consider. This should ONLY
    be used for training, not for validation or testing.

    allowable_batchNums contains the batch indices (across the whole dataset)
    that belong to this particular generator. This allows CV splitting to be
    done *outside* of this function; note that with the empty default set,
    the generator yields nothing.'''
    def bond_string_to_tuple(string):
        split = string.split('-')
        return (split[0], split[1], float(split[2]))

    fileInfo = [() for j in range(start_at, end_at, batch_size)]   # (filePos, startIndex, endIndex)
    batchDims = [() for j in range(start_at, end_at, batch_size)]  # dimensions of each batch
    batchNums = np.array([i for (i, j) in
                          enumerate(range(start_at, end_at, batch_size))])  # to shuffle later

    # Keep returning forever and ever
    with open(DATA_FPATH, 'rb') as fid:

        # Do a first pass through the data
        legend_data = pickle.load(fid)  # first doc is legend

        # Pre-load indices into the legend record
        CANDIDATE_EDITS_COMPACT = legend_data['candidate_edits_compact']
        ATOM_DESC_DICT = legend_data['atom_desc_dict']
        REACTION_TRUE_ONEHOT = legend_data['reaction_true_onehot']

        for i in range(start_at):
            pickle.load(fid)  # throw away first ___ entries

        for k, startIndex in enumerate(range(start_at, end_at, batch_size)):
            endIndex = min(startIndex + batch_size, end_at)

            # Remember this starting position
            fileInfo[k] = (fid.tell(), startIndex, endIndex)

            N = endIndex - startIndex  # number of samples this batch
            # print('Serving up examples {} through {}'.format(startIndex, endIndex))

            docs = [pickle.load(fid) for j in range(startIndex, endIndex)]

            # Need to figure out the size of the padded batch
            N_c = max([len(doc[REACTION_TRUE_ONEHOT]) for doc in docs])
            if max_N_c is not None:  # allow truncation during training
                N_c = min(N_c, max_N_c)
            N_e1 = 1
            N_e2 = 1
            N_e3 = 1
            N_e4 = 1
            for i, doc in enumerate(docs):
                for (c, edit_string) in enumerate(doc[CANDIDATE_EDITS_COMPACT]):
                    if c >= N_c: break
                    edit_string_split = edit_string.split(';')
                    N_e1 = max(N_e1, edit_string_split[0].count(',') + 1)
                    N_e2 = max(N_e2, edit_string_split[1].count(',') + 1)
                    N_e3 = max(N_e3, edit_string_split[2].count(',') + 1)
                    N_e4 = max(N_e4, edit_string_split[3].count(',') + 1)

            # Remember sizes of x_h_lost, x_h_gain, x_bond_lost, x_bond_gain, reaction_true_onehot
            batchDim = (N, N_c, N_e1, N_e2, N_e3, N_e4)

            # print('The padded sizes of this batch will be: N, N_c, N_e1, N_e2, N_e3, N_e4')
            # print(batchDim)
            batchDims[k] = batchDim

        while True:

            if shuffle: np.random.shuffle(batchNums)

            for batchNum in batchNums:
                # Skip batches assigned to a different generator (CV split)
                if batchNum not in allowable_batchNums: continue
                #print('data grabbed batchNum {}'.format(batchNum))

                (filePos, startIndex, endIndex) = fileInfo[batchNum]
                (N, N_c, N_e1, N_e2, N_e3, N_e4) = batchDims[batchNum]
                fid.seek(filePos)

                N = endIndex - startIndex  # number of samples this batch
                # print('Serving up examples {} through {}'.format(startIndex, endIndex))

                docs = [pickle.load(fid) for j in range(startIndex, endIndex)]

                if BASELINE_MODEL or HYBRID_MODEL:
                    x = np.zeros((N, N_c, 1024), dtype=np.float32)

                # Initialize numpy arrays for x_h_lost, etc.
                x_h_lost = np.zeros((N, N_c, N_e1, F_atom), dtype=np.float32)
                x_h_gain = np.zeros((N, N_c, N_e2, F_atom), dtype=np.float32)
                x_bond_lost = np.zeros((N, N_c, N_e3, F_bond), dtype=np.float32)
                x_bond_gain = np.zeros((N, N_c, N_e4, F_bond), dtype=np.float32)
                reaction_true_onehot = np.zeros((N, N_c), dtype=np.float32)

                for i, doc in enumerate(docs):

                    for (c, edit_string) in enumerate(doc[CANDIDATE_EDITS_COMPACT]):
                        if c >= N_c:
                            break

                        if BASELINE_MODEL or HYBRID_MODEL:
                            x[i, c, :] = doc[legend_data['prod_FPs']][c]

                        edit_string_split = edit_string.split(';')
                        edits = [
                            [a for a in edit_string_split[0].split(',') if a],  # H lost
                            [a for a in edit_string_split[1].split(',') if a],  # H gained
                            [bond_string_to_tuple(b)
                             for b in edit_string_split[2].split(',') if b],    # bonds lost
                            [bond_string_to_tuple(b)
                             for b in edit_string_split[3].split(',') if b],    # bonds gained
                        ]

                        try:
                            edit_h_lost_vec, edit_h_gain_vec, edit_bond_lost_vec, edit_bond_gain_vec = \
                                edits_to_vectors(edits, None,
                                                 atom_desc_dict=doc[ATOM_DESC_DICT],
                                                 ORIGINAL_VERSION=True)
                        except KeyError:  # molAtomMapNumber can be missing when hydrogens were explicit
                            continue

                        for (e, edit_h_lost) in enumerate(edit_h_lost_vec):
                            if e >= N_e1:
                                raise ValueError('N_e1 not large enough!')
                            x_h_lost[i, c, e, :] = edit_h_lost
                        for (e, edit_h_gain) in enumerate(edit_h_gain_vec):
                            if e >= N_e2:
                                raise ValueError('N_e2 not large enough!')
                            x_h_gain[i, c, e, :] = edit_h_gain
                        for (e, edit_bond_lost) in enumerate(edit_bond_lost_vec):
                            if e >= N_e3:
                                raise ValueError('N_e3 not large enough!')
                            x_bond_lost[i, c, e, :] = edit_bond_lost
                        for (e, edit_bond_gain) in enumerate(edit_bond_gain_vec):
                            if e >= N_e4:
                                raise ValueError('N_e4 not large enough!')
                            x_bond_gain[i, c, e, :] = edit_bond_gain

                    # Add the true-reaction one-hot vector, truncated to
                    # max_N_c if requested (eventually this will not truncate)
                    if max_N_c is None:
                        reaction_true_onehot[i, :len(doc[REACTION_TRUE_ONEHOT])] = \
                            doc[REACTION_TRUE_ONEHOT]
                    else:
                        n_keep = min(len(doc[REACTION_TRUE_ONEHOT]), max_N_c)
                        reaction_true_onehot[i, :n_keep] = \
                            doc[REACTION_TRUE_ONEHOT][:max_N_c]

                # Get rid of NaNs
                x_h_lost[np.isnan(x_h_lost)] = 0.0
                x_h_gain[np.isnan(x_h_gain)] = 0.0
                x_bond_lost[np.isnan(x_bond_lost)] = 0.0
                x_bond_gain[np.isnan(x_bond_gain)] = 0.0
                x_h_lost[np.isinf(x_h_lost)] = 0.0
                x_h_gain[np.isinf(x_h_gain)] = 0.0
                x_bond_lost[np.isinf(x_bond_lost)] = 0.0
                x_bond_gain[np.isinf(x_bond_gain)] = 0.0

                # print('Batch {} to {}'.format(startIndex, endIndex))
                # Yield (x, y) as a tuple; each element is a list of arrays
                y = reaction_true_onehot

                if BASELINE_MODEL:
                    yield ([x], [y])

                elif HYBRID_MODEL:
                    yield ([x_h_lost, x_h_gain, x_bond_lost, x_bond_gain, x], [y])

                else:
                    yield ([x_h_lost, x_h_gain, x_bond_lost, x_bond_gain], [y])
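
# Hypothetical usage sketch: split batch indices into train/validation sets
# outside the generator, as the docstring describes (N_samples, the batch size
# of 32, and the 90/10 split are illustrative assumptions):
# n_batches = (N_samples + 31) // 32
# val_batches = {b for b in range(n_batches) if b % 10 == 0}
# train_batches = set(range(n_batches)) - val_batches
# train_gen = data_generator(0, N_samples, 32, max_N_c=500, shuffle=True,
#                            allowable_batchNums=train_batches)
# val_gen = data_generator(0, N_samples, 32, allowable_batchNums=val_batches)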