Example #1
import pickle

import numpy as np

# Assumes module-level globals DATA_FPATH, TARGET_YIELD, F_atom, F_bond and the
# edits_to_vectors() helper are defined elsewhere in this module.
def data_generator(start_at, end_at, batch_size, max_N_c=None, shuffle=False):
    '''This function generates batches of data from the
    pickle file since all the data can't fit in memory.

    The starting and ending indices are specified explicitly so the
    same function can be used for validation data as well

    Input tensors are generated on-the-fly so there is less I/O

    max_N_c is the maximum number of candidates to consider. This should ONLY be used
    for training, not for validation or testing.'''
    def bond_string_to_tuple(string):
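        # e.g., '3-5-2.0' -> ('3', '5', 2.0): two atom-map numbers plus
        # (presumably) the bond order as a float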
        split = string.split('-')
        return (split[0], split[1], float(split[2]))

    fileInfo = [() for j in range(start_at, end_at, batch_size)]  # (filePos, startIndex, endIndex)
    batchDims = [() for j in range(start_at, end_at, batch_size)]  # dimensions of each batch
    batchNums = np.arange(len(fileInfo))  # batch indices, shuffled later if requested

    # Keep returning forever and ever
    with open(DATA_FPATH, 'rb') as fid:

        # Do a first pass through the data
        legend_data = pickle.load(fid)  # first doc is legend

        # Pre-load indices
        CANDIDATE_EDITS_COMPACT = legend_data['candidate_edits_compact']
        ATOM_DESC_DICT = legend_data['atom_desc_dict']
        T = legend_data['T']
        SOLVENT = legend_data['solvent']
        REAGENT = legend_data['reagent']
        YIELD = legend_data['yield']
        REACTION_TRUE_ONEHOT = legend_data['reaction_true_onehot']

        for i in range(start_at):
            pickle.load(fid)  # skip the first start_at entries

        for k, startIndex in enumerate(range(start_at, end_at, batch_size)):
            endIndex = min(startIndex + batch_size, end_at)

            # Remember this starting position
            fileInfo[k] = (fid.tell(), startIndex, endIndex)

            N = endIndex - startIndex  # number of samples this batch
            # print('Serving up examples {} through {}'.format(startIndex, endIndex))

            docs = [pickle.load(fid) for j in range(startIndex, endIndex)]

            # Need to figure out the size of the padded batch
            N_c = max([len(doc[REACTION_TRUE_ONEHOT]) for doc in docs])
            if max_N_c is not None:  # allow truncation during training
                N_c = min(N_c, max_N_c)
            N_e1 = 1
            N_e2 = 1
            N_e3 = 1
            N_e4 = 1
            for i, doc in enumerate(docs):
                for (c, edit_string) in enumerate(doc[CANDIDATE_EDITS_COMPACT]):
                    if c >= N_c:
                        break
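                    # Each compact edit string packs four ';'-separated fields:
                    # H-lost atoms, H-gained atoms, bonds lost, bonds gained,
                    # each field itself a ','-separated list of entries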
                    edit_string_split = edit_string.split(';')
                    N_e1 = max(N_e1, edit_string_split[0].count(',') + 1)
                    N_e2 = max(N_e2, edit_string_split[1].count(',') + 1)
                    N_e3 = max(N_e3, edit_string_split[2].count(',') + 1)
                    N_e4 = max(N_e4, edit_string_split[3].count(',') + 1)

            # Remember sizes of x_h_lost, x_h_gain, x_bond_lost, x_bond_gain, reaction_true_onehot
            batchDim = (N, N_c, N_e1, N_e2, N_e3, N_e4)

            # print('The padded sizes of this batch will be: N, N_c, N_e1, N_e2, N_e3, N_e4')
            # print(batchDim)
            batchDims[k] = batchDim

        while True:

            if shuffle: np.random.shuffle(batchNums)

            for batchNum in batchNums:
                (filePos, startIndex, endIndex) = fileInfo[batchNum]
                (N, N_c, N_e1, N_e2, N_e3, N_e4) = batchDims[batchNum]
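                # Jump back to the recorded file offset for this batch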
                fid.seek(filePos)

                N = endIndex - startIndex  # number of samples this batch
                # print('Serving up examples {} through {}'.format(startIndex, endIndex))

                docs = [pickle.load(fid) for j in range(startIndex, endIndex)]

                # Initialize numpy arrays for x_h_lost, etc.
                x_h_lost = np.zeros((N, N_c, N_e1, F_atom), dtype=np.float32)
                x_h_gain = np.zeros((N, N_c, N_e2, F_atom), dtype=np.float32)
                x_bond_lost = np.zeros((N, N_c, N_e3, F_bond), dtype=np.float32)
                x_bond_gain = np.zeros((N, N_c, N_e4, F_bond), dtype=np.float32)
                reaction_true_onehot = np.zeros((N, N_c), dtype=np.float32)
                yields = np.zeros((N, 1), dtype=np.float32)

                for i, doc in enumerate(docs):

                    for (c, edit_string) in enumerate(doc[CANDIDATE_EDITS_COMPACT]):
                        if c >= N_c:
                            break

                        edit_string_split = edit_string.split(';')
                        edits = [
                            [atom_string for atom_string in edit_string_split[0].split(',') if atom_string],
                            [atom_string for atom_string in edit_string_split[1].split(',') if atom_string],
                            [bond_string_to_tuple(bond_string) for bond_string in edit_string_split[2].split(',') if bond_string],
                            [bond_string_to_tuple(bond_string) for bond_string in edit_string_split[3].split(',') if bond_string],
                        ]

                        try:
                            edit_h_lost_vec, edit_h_gain_vec, edit_bond_lost_vec, edit_bond_gain_vec = \
                                edits_to_vectors(edits, None, atom_desc_dict=doc[ATOM_DESC_DICT])
                        except KeyError:  # sometimes molAtomMapNumber not found if hydrogens were explicit
                            continue

                        for (e, edit_h_lost) in enumerate(edit_h_lost_vec):
                            if e >= N_e1:
                                raise ValueError('N_e1 not large enough!')
                            x_h_lost[i, c, e, :] = edit_h_lost
                        for (e, edit_h_gain) in enumerate(edit_h_gain_vec):
                            if e >= N_e2:
                                raise ValueError('N_e2 not large enough!')
                            x_h_gain[i, c, e, :] = edit_h_gain
                        for (e, edit_bond_lost) in enumerate(edit_bond_lost_vec):
                            if e >= N_e3:
                                raise ValueError('N_e3 not large enough!')
                            x_bond_lost[i, c, e, :] = edit_bond_lost
                        for (e, edit_bond_gain) in enumerate(edit_bond_gain_vec):
                            if e >= N_e4:
                                raise ValueError('N_e4 not large enough!')
                            x_bond_gain[i, c, e, :] = edit_bond_gain

                    # Add truncated reaction true (eventually will not truncate)
                    if max_N_c is None:
                        reaction_true_onehot[i, :len(doc[REACTION_TRUE_ONEHOT])] = \
                            doc[REACTION_TRUE_ONEHOT]
                    else:
                        reaction_true_onehot[i, :min(len(doc[REACTION_TRUE_ONEHOT]), max_N_c)] = \
                            doc[REACTION_TRUE_ONEHOT][:max_N_c]
                    yields[i, 0] = doc[YIELD] / 100.0

                # Get rid of NaNs and infs
                x_h_lost[np.isnan(x_h_lost)] = 0.0
                x_h_gain[np.isnan(x_h_gain)] = 0.0
                x_bond_lost[np.isnan(x_bond_lost)] = 0.0
                x_bond_gain[np.isnan(x_bond_gain)] = 0.0
                x_h_lost[np.isinf(x_h_lost)] = 0.0
                x_h_gain[np.isinf(x_h_gain)] = 0.0
                x_bond_lost[np.isinf(x_bond_lost)] = 0.0
                x_bond_gain[np.isinf(x_bond_gain)] = 0.0

                # print('Batch {} to {}'.format(startIndex, endIndex))
                # yield (x, y) as tuple, but each one is a list

                if TARGET_YIELD:
                    y = yields
                else:
                    y = reaction_true_onehot

                yield (
                    [
                        x_h_lost,
                        x_h_gain,
                        x_bond_lost,
                        x_bond_gain,
                        np.array([doc[REAGENT] for doc in docs],
                                 dtype=np.float32),  # reagent
                        np.array([doc[SOLVENT] for doc in docs],
                                 dtype=np.float32),  # solvent
                        np.array([doc[T] for doc in docs],
                                 dtype=np.float32),  # temperature
                    ],
                    [
                        y,
                    ],
                )
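For reference, a minimal sketch of how this generator might be consumed by a
Keras model (the model, index bounds, and batch size below are illustrative
assumptions, not part of the original snippet):

# Hypothetical usage: train on the first 8000 entries, validate on the rest
train_gen = data_generator(0, 8000, batch_size=20, max_N_c=500, shuffle=True)
val_gen = data_generator(8000, 10000, batch_size=20)  # no truncation outside training

model.fit_generator(
    train_gen,
    steps_per_epoch=int(np.ceil(8000 / 20.0)),
    epochs=10,
    validation_data=val_gen,
    validation_steps=int(np.ceil(2000 / 20.0)),
)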
Example #2
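    # Method of a template-based scorer class. Assumes rdkit's Chem, a Queue
    # class, MyLogger, context_cleaner, gc (global configuration), softmax,
    # clean_reactant_mapping, and edits_to_vectors/edits_to_tensor are available
    # at module scope.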
    def evaluate(self, reactants_smiles, contexts, **kwargs):
        self.reset()
        self.nproc = kwargs.pop('nproc', 1)
        batch_size = kwargs.pop('batch_size', 250)
        if not self.celery:
            for i in range(self.nproc):
                self.idle.append(True)
                self.expansion_queue = Queue()

        mol = Chem.MolFromSmiles(reactants_smiles)
        if mol is None:
            MyLogger.print_and_log('Reactants SMILES not parseable: {}'.format(
                reactants_smiles), template_nn_scorer_loc, level=1)
            return [[{
                'rank': 1,
                'outcome': '',
                'score': 0,
                'prob': 0,
            }]]

        clean_reactant_mapping(mol)

        reactants_smiles = Chem.MolToSmiles(mol)
        if self.celery:
            from celery.result import allow_join_result
        else:
            from makeit.utilities.with_dummy import with_dummy as allow_join_result
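        # allow_join_result lets celery block on subtask results from within a
        # task; with_dummy presumably stands in as a no-op context manager when
        # celery is not used.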
        with allow_join_result():
            self.template_prioritization = kwargs.pop('template_prioritization', gc.popularity)
            self.prepare()
            self.initialize(reactants_smiles, batch_size)
            (all_results, candidate_edits) = self.get_candidate_edits(reactants_smiles)
            reactants = Chem.MolFromSmiles(reactants_smiles)
            atom_desc_dict = edits_to_vectors(
                [], reactants, return_atom_desc_dict=True)
            candidate_tensor = edits_to_tensor(
                candidate_edits, reactants, atom_desc_dict)

            if not candidate_tensor:
                return [[{
                    'rank': 1,
                    'outcome': '',
                    'score': 0,
                    'prob': 0,
                }]]

            use_softmax = kwargs.pop('soft_max', True)  # pop once, before the context loop
            all_outcomes = []
            for context in contexts:
                if context == []:
                    all_outcomes.append({'rank': 0.0,
                                         'outcome': None,
                                         'score': 0.0,
                                         'prob': 0.0,
                                         })
                    continue
                # prediction_context = context_cleaner.clean_context(context) ## move this step to tree evaluator
                prediction_context = context
                context_tensor = context_cleaner.context_to_edit(
                    prediction_context, self.solvent_name_to_smiles, self.solvent_smiles_to_params)
                if not context_tensor: 
                    all_outcomes.append({'rank': 0.0,
                                         'outcome': None,
                                         'score': 0.0,
                                         'prob': 0.0,
                                         })
                    continue
                scores = self.model.predict(candidate_tensor + context_tensor)
                probs = scores

                if use_softmax:
                    probs = softmax(scores)

                this_outcome = sorted(zip(all_results, scores[0], probs[0]),
                                      key=lambda x: x[2], reverse=True)

                # Convert to outcome dict, canonicalizing by SMILES
                outcome_dict = {}
                for i, outcome in enumerate(this_outcome):
                    try:
                        outcome_smiles = outcome[0].smiles
                    except AttributeError:
                        outcome_smiles = outcome[0]['smiles']
                    if outcome_smiles not in outcome_dict:
                        outcome_dict[outcome_smiles] = {
                            'rank': i + 1,
                            'outcome': outcome[0],
                            'score': float(outcome[1]),
                            'prob': float(outcome[2]),
                        }
                    else: # just add probability
                        outcome_dict[outcome_smiles]['prob'] += float(outcome[2])

                all_outcomes.append(sorted(list(outcome_dict.values()), key=lambda x: x['prob'], reverse=True))

            return all_outcomes
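The softmax helper used above is not defined in this snippet; a minimal NumPy
implementation consistent with the call site (row-wise over a 2-D score array)
might look like the following sketch, though the original helper may differ:

import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis` (a sketch, not the original helper)
    x = np.asarray(x, dtype=np.float64)
    shifted = x - x.max(axis=axis, keepdims=True)  # subtract row max to avoid overflow
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)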
Example #3
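    # Script fragment: assumes `os`, rdkit's Chem, an argparse `args` namespace,
    # FROOT, N_h1/N_h2, and the build()/edits_to_vectors() helpers are defined
    # earlier in the script.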
    WEIGHTS_FPATH = os.path.join(FROOT, 'weights.h5')
    HIST_FPATH = os.path.join(FROOT, 'hist.csv')
    TEST_FPATH = os.path.join(FROOT, 'probs.dat')
    HISTOGRAM_FPATH = os.path.join(FROOT, 'histogram %s.png')
    ARGS_FPATH = os.path.join(FROOT, 'args.json')

    import json

    with open(ARGS_FPATH, 'w') as fid:
        json.dump(args.__dict__, fid)

    DATA_FPATH = '{}_data.pickle'.format(args.data_tag)
    LABELS_FPATH = '{}_labels.pickle'.format(args.data_tag)

    this_dir = os.getcwd()
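    # Probe the feature dimensionality with a trivially mapped ethane edit; the
    # cwd is saved and restored in case edits_to_vectors changes the working
    # directory.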
    mol = Chem.MolFromSmiles('[CH3:1][CH3:2]')
    (a, _, b, _) = edits_to_vectors((['1'], [], [('1', '2', 1.0)], []), mol)
    os.chdir(this_dir)

    F_atom = len(a[0])
    F_bond = len(b[0])

    if bool(args.retrain):
        print('Reloading from file')
        rebuild = input(  # raw_input under Python 2
            'Do you want to rebuild from scratch instead of loading from file? [n/y] '
        )
        if rebuild == 'y':
            model = build(F_atom=F_atom,
                          F_bond=F_bond,
                          N_h1=N_h1,
                          N_h2=N_h2,