Esempio n. 1
0
def load_Pricer(chemical_database, buyable_database):
    '''
    Build a Pricer from a chemicals database and a buyable-chemicals database.

    Args:
        chemical_database: Passed as the first argument to ``Pricer.load``.
        buyable_database: Passed as the second argument to ``Pricer.load``.

    Returns:
        Pricer: The freshly loaded pricer.
    '''
    log = MyLogger.print_and_log
    log('Loading pricing model...', model_loader_loc)
    loaded = Pricer()
    loaded.load(chemical_database, buyable_database)
    log('Pricer Loaded.', model_loader_loc)
    return loaded
Esempio n. 2
0
class RelevanceHeuristicPrecursorPrioritizer(Prioritizer):
    """A precursor Prioritizer that uses a heuristic and template relevance.

    Each precursor SMILES contributes ``-ppg / 1000`` when the Pricer reports
    a (truthy) price. Otherwise it contributes a structural penalty based on
    heavy atoms, non-aromatic ring bonds and chiral centers. The summed score,
    minus a reagent penalty, is divided by the precursor's template score.

    Attributes:
        pricer (Pricer or None): Used to look up chemical prices; created by
            ``load_model`` on first use.
    """
    def __init__(self):
        """Initializes RelevanceHeuristicPrecursorPrioritizer (unloaded)."""
        self._loaded = False
        self.pricer = None

    @staticmethod
    def _structural_penalty(smiles):
        """Return the (negative) heuristic score of a non-buyable SMILES."""
        mol = Chem.MolFromSmiles(smiles)
        heavy_atoms = mol.GetNumHeavyAtoms()
        # Each in-ring, non-aromatic bond contributes 1 to this count.
        nonaromatic_ring_bonds = sum([bond.IsInRing() - bond.GetIsAromatic()
                                      for bond in mol.GetBonds()])
        n_chiral = len(Chem.FindMolChiralCenters(mol))
        return (- 2.00 * np.power(heavy_atoms, 1.5)
                - 1.00 * np.power(nonaromatic_ring_bonds, 1.5)
                - 2.00 * np.power(n_chiral, 2.0))

    def get_priority(self, retroPrecursor, **kwargs):
        """Gets priority of given precursor based on heuristic and relevance.

        Args:
            retroPrecursor (RetroPrecursor): Precursor to calculate priority of;
                its ``smiles_list``, ``necessary_reagent`` and
                ``template_score`` attributes are read.
            **kwargs: Unused.

        Returns:
            float: Priority score of precursor.
        """
        if not self._loaded:
            self.load_model()

        # Half the number of '[' characters in the necessary reagent string.
        reagent_atoms = retroPrecursor.necessary_reagent.count('[') / 2.
        partial = []
        for smi in retroPrecursor.smiles_list:
            ppg = self.pricer.lookup_smiles(smi, alreadyCanonical=True)
            # Buyable precursors are basically free.
            partial.append(-ppg / 1000.0 if ppg else self._structural_penalty(smi))

        total = np.sum(partial) - 4.00 * np.power(reagent_atoms, 2.0)
        return total / retroPrecursor.template_score

    def load_model(self):
        """Creates and loads the Pricer used by the heuristic scoring."""
        pricer = Pricer()
        pricer.load()
        self.pricer = pricer
        self._loaded = True
Esempio n. 3
0
def configure_coordinator(options=None, **kwargs):
    """Initialize the global tree builder coordinator for this worker.

    Nothing happens unless ``CORRESPONDING_QUEUE`` appears among the
    comma-separated queue names in ``options['queues']``. When it does, the
    module-level ``evaluator`` and ``treeBuilder`` globals are (re)created, the
    latter with a freshly loaded Pricer.

    Args:
        options (dict, optional): Worker options; only the ``'queues'`` key is
            read. ``None`` (the default) is treated as an empty dict.
        **kwargs: Unused; accepted so extra keyword arguments are tolerated.
    """
    # Use a None default instead of a shared mutable ``{}`` default.
    if options is None:
        options = {}
    if 'queues' not in options:
        return
    if CORRESPONDING_QUEUE not in options['queues'].split(','):
        return
    print('### STARTING UP A TREE BUILDER COORDINATOR ###')

    global treeBuilder
    global evaluator

    evaluator = Evaluator(celery=True)

    # Prices
    print('Loading prices...')
    pricer = Pricer()
    pricer.load()
    print('Loaded known prices')
    treeBuilder = TreeBuilder(celery=True, pricer=pricer)

    print('Finished initializing treebuilder coordinator')
Esempio n. 4
0
class HeuristicPrecursorPrioritizer(Prioritizer):
    """Precursor prioritizer driven purely by a structural heuristic.

    Each precursor SMILES contributes ``-ppg / 5`` when the Pricer reports a
    (truthy) price. Otherwise it contributes a penalty based on heavy-atom
    count, non-aromatic ring bonds and chiral centers. A further penalty grows
    with the number of '[' characters in the necessary reagent.

    Attributes:
        pricer (Pricer or None): Price lookup, created by ``load_model`` on
            first use.
    """

    def __init__(self):
        """Start unloaded; the Pricer is created lazily."""
        self._loaded = False
        self.pricer = None

    def _precursor_term(self, smiles):
        """Return the score contribution of a single precursor SMILES."""
        price = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
        if price:
            # Buyable: basically free.
            return - price / 5.0

        mol = Chem.MolFromSmiles(smiles)
        heavy_atoms = mol.GetNumHeavyAtoms()
        # Each in-ring, non-aromatic bond contributes 1 to this count.
        ring_count = sum([bond.IsInRing() - bond.GetIsAromatic()
                          for bond in mol.GetBonds()])
        n_chiral = len(Chem.FindMolChiralCenters(mol))
        return (- 2.00 * np.power(heavy_atoms, 1.5)
                - 1.00 * np.power(ring_count, 1.5)
                - 2.00 * np.power(n_chiral, 2.0))

    def get_priority(self, retroPrecursor, **kwargs):
        """Score a precursor with the structural heuristic.

        Args:
            retroPrecursor (RetroPrecursor): Precursor to score; its
                ``smiles_list`` and ``necessary_reagent`` attributes are read.
            **kwargs: Unused.

        Returns:
            float: Priority score (higher is better).
        """
        if not self._loaded:
            self.load_model()

        # Half the number of '[' characters in the necessary reagent string.
        reagent_atoms = retroPrecursor.necessary_reagent.count('[') / 2.
        terms = [self._precursor_term(smi) for smi in retroPrecursor.smiles_list]
        return np.sum(terms) - 4.00 * np.power(reagent_atoms, 2.0)

    def load_model(self):
        """Create and load the Pricer used for buyability lookups."""
        pricer = Pricer()
        pricer.load()
        self.pricer = pricer
        self._loaded = True
Esempio n. 5
0
class MinCostPrecursorPrioritizer(Prioritizer):
    '''
    Standalone, importable MinCost precursor prioritizer backed by a Keras model.

    Buyable precursors (truthy price from the Pricer) score 0.0; all other
    precursors are scored by the network applied to a Morgan fingerprint.
    '''

    def __init__(self, score_scale=10.0):
        '''
        Args:
            score_scale (float): Stored on the instance; not read by the
                scoring methods of this class.
        '''
        self.vars = []
        self.FP_rad = 3
        self.score_scale = score_scale

        self._restored = False
        self.pricer = None
        self._loaded = False

    def load_model(self, FP_len=1024, input_layer=6144, hidden_layer=512, modelpath=""):
        '''
        Build the network, load its weights, and load the Pricer.

        Args:
            FP_len (int): Fingerprint length (network input width).
            input_layer (int): Width of the first Dense layer.
            hidden_layer (int): Width of each of the four hidden Dense layers.
            modelpath (str): Weights file. When empty, the path from
                gc.MinCost_Prioritiaztion['trained_model_path'] is used.
        '''
        self.FP_len = FP_len

        def last_layer(x):
            # Squash the raw network output into [0, 11).
            return 11 - 11 * K.exp(- K.abs(x / 1500000))

        model = Sequential()
        model.add(Dense(input_layer, activation="relu", batch_input_shape=(None, self.FP_len)))
        for _ in range(4):
            model.add(Dense(hidden_layer, activation="relu"))
        model.add(Dense(1))
        model.add(Lambda(last_layer, output_shape=(None, 1)))
        if modelpath == "":
            modelpath = gc.MinCost_Prioritiaztion['trained_model_path']
        model.load_weights(modelpath)
        self.model = model

        def mol_to_fp(mol):
            if mol is None:
                return np.zeros((self.FP_len,), dtype=np.float32)
            # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, self.FP_rad, nBits=self.FP_len,
                                                                  useChirality=True), dtype=bool)
        self.mol_to_fp = mol_to_fp

        self.pricer = Pricer()
        self.pricer.load()
        self._restored = True
        self._loaded = True

    def smi_to_fp(self, smi):
        '''Return the fingerprint of a SMILES string (all zeros if empty).'''
        if not smi:
            return np.zeros((self.FP_len,), dtype=np.float32)
        return self.mol_to_fp(Chem.MolFromSmiles(smi))

    def get_price(self, smi):
        '''Return 0.0 if the chemical has a (truthy) price, else None.'''
        ppg = self.pricer.lookup_smiles(smi, alreadyCanonical=True)
        if ppg:
            return 0.0
        return None

    def get_priority(self, retroProduct, **kwargs):
        '''
        Score a SMILES string or a precursor object.

        Args:
            retroProduct (str or RetroPrecursor): A SMILES string is scored
                directly; otherwise the scores of its ``smiles_list`` are summed.
            **kwargs: Accepted for interface compatibility; ignored
                (including ``mode``).

        Returns:
            float: Score of the product / sum of precursor scores.
        '''
        if not self._loaded:
            self.load_model()

        if isinstance(retroProduct, str):
            return self.get_score_from_smiles(retroProduct)
        return sum(self.get_score_from_smiles(smiles)
                   for smiles in retroProduct.smiles_list)

    def get_score_from_smiles(self, smiles):
        '''
        Return 0.0 for buyable chemicals, otherwise the model prediction.

        A SMILES with an all-zero fingerprint (empty or unparseable) scores 0.
        '''
        # Check buyable
        ppg = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
        if ppg:
            return 0.0

        fp = np.array(self.smi_to_fp(smiles), dtype=np.float32)
        if not fp.any():
            return 0.
        return self.model.predict(fp.reshape((1, self.FP_len)))[0][0]
Esempio n. 6
0
class SCScorePrecursorPrioritizer(Prioritizer):
    '''
    Standalone, importable SCScore model. It does not depend on tensorflow and is
    a more attractive option for deployment. The calculations are fast enough
    that there is no real reason to use GPUs (via tf) instead of CPUs (via np).
    '''

    def __init__(self, score_scale=5.0):
        '''
        Args:
            score_scale (float): ``apply`` maps network output to
                ``1 + (score_scale - 1) * sigmoid(x)``.
        '''
        self.vars = []
        self.FP_rad = 2
        self.score_scale = score_scale
        self._restored = False
        self.pricer = None
        self._loaded = False

    def load_model(self, FP_len=1024, model_tag='1024bool'):
        '''
        Load pickled SCScore weights, build the fingerprint function, load the Pricer.

        Args:
            FP_len (int): Fingerprint length. NOTE(review): not derived from
                model_tag, so it presumably must match the chosen model's width.
            model_tag (str): '1024bool', '1024uint8' or '2048bool'. Any other
                value is logged and replaced by '1024bool'.
        '''
        self.FP_len = FP_len
        if model_tag not in ('1024bool', '1024uint8', '2048bool'):
            MyLogger.print_and_log(
                'Non-existent SCScore model requested: {}. Using "1024bool" model'.format(model_tag), scscore_prioritizer_loc, level=2)
            model_tag = '1024bool'
        filename = 'trained_model_path_' + model_tag
        model_path = gc.SCScore_Prioritiaztion[filename]
        with open(model_path, 'rb') as fid:
            self.vars = pickle.load(fid)
        if gc.DEBUG:
            MyLogger.print_and_log('Loaded synthetic complexity score prioritization model from {}'.format(
                model_path), scscore_prioritizer_loc)

        if 'uint8' in model_path:
            def mol_to_fp(mol):
                if mol is None:
                    # All-zero count vector of length FP_len.
                    return np.zeros((self.FP_len,), dtype=np.uint8)
                fp = AllChem.GetMorganFingerprint(
                    mol, self.FP_rad, useChirality=True)  # UIntSparseIntVect
                # Fold the sparse count fingerprint into FP_len bins.
                fp_folded = np.zeros((self.FP_len,), dtype=np.uint8)
                for k, v in fp.GetNonzeroElements().items():
                    fp_folded[k % self.FP_len] += v
                return fp_folded
        else:
            def mol_to_fp(mol):
                if mol is None:
                    return np.zeros((self.FP_len,), dtype=np.float32)
                # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
                return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, self.FP_rad, nBits=self.FP_len,
                                                                      useChirality=True), dtype=bool)
        self.mol_to_fp = mol_to_fp

        self.pricer = Pricer()
        self.pricer.load()
        self._restored = True
        self._loaded = True

    def smi_to_fp(self, smi):
        '''Return the fingerprint of a SMILES string (all zeros if empty).'''
        if not smi:
            return np.zeros((self.FP_len,), dtype=np.float32)
        return self.mol_to_fp(Chem.MolFromSmiles(str(smi)))

    def apply(self, x):
        '''
        Run the dense network stored in ``self.vars`` on fingerprint ``x``.

        Raises:
            ValueError: If the weights have not been loaded.
        '''
        if not self._restored:
            raise ValueError('Must restore model weights!')
        # Each consecutive pair of vars is a (weight, bias) layer.
        for i in range(0, len(self.vars), 2):
            is_last = (i == (len(self.vars) - 2))
            W = self.vars[i]
            b = self.vars[i + 1]
            x = np.dot(W.T, x) + b
            if not is_last:
                x = x * (x > 0)  # ReLU
        return 1 + (self.score_scale - 1) * sigmoid(x)

    def get_priority(self, retroProduct, **kwargs):
        '''
        Return the negated SCScore of a SMILES string or a precursor object.

        Args:
            retroProduct (str or RetroPrecursor): A SMILES string is scored
                directly; otherwise scores of its ``smiles_list`` are merged.
            **kwargs: ``mode`` selects the merge strategy (default gc.max).
        '''
        mode = kwargs.get('mode', gc.max)
        if not self._loaded:
            self.load_model()

        if isinstance(retroProduct, str):
            return -self.get_score_from_smiles(retroProduct)
        scores = [self.get_score_from_smiles(smiles)
                  for smiles in retroProduct.smiles_list]
        return -self.merge_scores(scores, mode=mode)

    def merge_scores(self, list_of_scores, mode=gc.max):
        '''
        Combine per-precursor scores.

        gc.mean -> arithmetic mean, gc.geometric -> geometric mean,
        gc.pow8 -> sum of 8**score, anything else -> maximum.
        '''
        if mode == gc.mean:
            return np.mean(list_of_scores)
        elif mode == gc.geometric:
            return np.power(np.prod(list_of_scores), 1.0 / len(list_of_scores))
        elif mode == gc.pow8:
            return np.sum([8**score for score in list_of_scores])
        else:
            return np.max(list_of_scores)

    def get_score_from_smiles(self, smiles, noprice=False):
        '''
        Score one SMILES string.

        Unless ``noprice`` is set, a chemical with a truthy price returns
        ``ppg / 100``. A SMILES with an all-zero fingerprint scores 0.
        '''
        # Check buyable
        if not noprice:
            ppg = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
            if ppg:
                return ppg / 100.

        fp = np.array(self.smi_to_fp(smiles), dtype=np.float32)
        if not fp.any():
            return 0.
        return self.apply(fp)
Esempio n. 7
0
class MinCostPrecursorPrioritizer(Prioritizer):
    """A precursor prioritizer using a MinCost model.

    This is a standalone, importable MinCost model. Uses loaded keras model.
    Buyable precursors (truthy price from the Pricer) score 0.0; others are
    scored by the network applied to a Morgan fingerprint.

    Attributes:
        vars (list): Unused.
        FP_rad (int): Fingerprint radius.
        FP_len (int): Fingerprint length (set by ``load_model``).
        score_scale (float): Stored but not read by this class's scoring.
        pricer (Pricer or None): Pricer instance to lookup chemical costs.
    """
    def __init__(self, score_scale=10.0):
        """Initializes MinCostPrecursorPrioritizer.

        Args:
            score_scale (float, optional): Stored on the instance.
                (default: {10})
        """
        self.vars = []
        self.FP_rad = 3
        self.score_scale = score_scale

        self._restored = False
        self.pricer = None
        self._loaded = False

    def load_model(self,
                   FP_len=1024,
                   input_layer=6144,
                   hidden_layer=512,
                   modelpath=""):
        """Builds the MinCost network, loads its weights and the Pricer.

        Args:
            FP_len (int, optional): Fingerprint length. (default: {1024})
            input_layer (int, optional): Width of the first Dense layer.
                (default: {6144})
            hidden_layer (int, optional): Width of each of the four hidden
                Dense layers. (default: {512})
            modelpath (str, optional): File containing model weights. When
                empty, gc.MinCost_Prioritiaztion['trained_model_path'] is
                used. (default: {''})
        """
        self.FP_len = FP_len

        def last_layer(x):
            # Squash the raw network output into [0, 11).
            return 11 - 11 * K.exp(-K.abs(x / 1500000))

        model = Sequential()
        model.add(
            Dense(input_layer,
                  activation="relu",
                  batch_input_shape=(None, self.FP_len)))
        for _ in range(4):
            model.add(Dense(hidden_layer, activation="relu"))
        model.add(Dense(1))
        model.add(Lambda(last_layer, output_shape=(None, 1)))
        if modelpath == "":
            modelpath = gc.MinCost_Prioritiaztion['trained_model_path']
        model.load_weights(modelpath)
        self.model = model

        def mol_to_fp(mol):
            """Returns fingerprint of given molecule.

            Args:
                mol (Chem.rdchem.Mol or None): Molecule to get fingerprint
                    of.

            Returns:
                np.ndarray of bool or np.float32: Fingerprint of given
                    molecule (float32 zeros when ``mol`` is None).
            """
            if mol is None:
                return np.zeros((self.FP_len, ), dtype=np.float32)
            # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            return np.array(AllChem.GetMorganFingerprintAsBitVect(
                mol, self.FP_rad, nBits=self.FP_len, useChirality=True),
                            dtype=bool)

        self.mol_to_fp = mol_to_fp

        self.pricer = Pricer()
        self.pricer.load()
        self._restored = True
        self._loaded = True

    def smi_to_fp(self, smi):
        """Returns fingerprint of molecule from given SMILES string.

        Args:
            smi (str): SMILES string of given molecule.

        Returns:
            np.ndarray: Fingerprint; float32 zeros when ``smi`` is empty.
        """
        if not smi:
            return np.zeros((self.FP_len, ), dtype=np.float32)
        return self.mol_to_fp(Chem.MolFromSmiles(smi))

    def get_price(self, smi):
        """Gets price of given chemical.

        Args:
            smi (str): SMILES string for given chemical.

        Returns:
            float or None: 0.0 if price is available, None if not.
        """
        ppg = self.pricer.lookup_smiles(smi, alreadyCanonical=True)
        if ppg:
            return 0.0
        return None

    def get_priority(self, retroProduct, **kwargs):
        """Returns priority of given product based on MinCost model.

        Args:
            retroProduct (str or RetroPrecursor): Product to calculate score
                for. A string is scored directly; otherwise the scores of its
                ``smiles_list`` are summed.
            **kwargs: Accepted for interface compatibility; ignored
                (including ``mode``).

        Returns:
            float: Priority of given product.
        """
        if not self._loaded:
            self.load_model()

        if isinstance(retroProduct, str):
            return self.get_score_from_smiles(retroProduct)
        return sum(self.get_score_from_smiles(smiles)
                   for smiles in retroProduct.smiles_list)

    def get_score_from_smiles(self, smiles):
        """Gets precursor score from a given SMILES string.

        Args:
            smiles (str): SMILES string of precursor.

        Returns:
            float: 0.0 for buyable chemicals or all-zero fingerprints,
                otherwise the model prediction.
        """
        # Check buyable
        ppg = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
        if ppg:
            return 0.0

        fp = np.array(self.smi_to_fp(smiles), dtype=np.float32)
        if not fp.any():
            return 0.
        return self.model.predict(fp.reshape((1, self.FP_len)))[0][0]
Esempio n. 8
0
class RelevanceHeuristicPrecursorPrioritizer(Prioritizer):
    """A precursor Prioritizer that uses a heuristic and template relevance.

    Attributes:
        pricer (Pricer or None): Used to look up chemical prices; created by
            ``load_model`` on first use.
    """
    def __init__(self):
        """Initializes RelevanceHeuristicPrecursorPrioritizer."""
        self.pricer = None
        self._loaded = False

    def _combined_score(self, smiles_list, necessary_reagent, template_score):
        """Shared scoring for ``score_precursor`` and ``get_priority``.

        Args:
            smiles_list (list of str): SMILES of the precursor fragments.
            necessary_reagent (str): Necessary reagent string; half the
                number of '[' characters in it sets the reagent penalty.
            template_score (float): Divisor applied to the summed score.

        Returns:
            float: Combined relevance heuristic score.
        """
        if not self._loaded:
            self.load_model()

        necessary_reagent_atoms = necessary_reagent.count('[') / 2.
        scores = []
        for smiles in smiles_list:
            ppg = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
            # If buyable, basically free
            if ppg:
                scores.append(-ppg / 1000.0)
                continue

            # Else, use heuristic
            mol = Chem.MolFromSmiles(smiles)
            total_atoms = mol.GetNumHeavyAtoms()
            # Each in-ring, non-aromatic bond contributes 1 to this count.
            ring_bonds = sum(
                [b.IsInRing() - b.GetIsAromatic() for b in mol.GetBonds()])
            chiral_centers = len(Chem.FindMolChiralCenters(mol))

            scores.append(-2.00 * np.power(total_atoms, 1.5) -
                          1.00 * np.power(ring_bonds, 1.5) -
                          2.00 * np.power(chiral_centers, 2.0))

        sco = np.sum(scores) - 4.00 * np.power(necessary_reagent_atoms, 2.0)
        return sco / template_score

    def score_precursor(self, precursor):
        """Score a given precursor using a combination of the template relevance score and a heuristic rule

        Loads the Pricer first if it has not been loaded yet.

        Args:
            precursor (dict): dictionary of precursor to score; the
                'smiles_split', 'necessary_reagent' and 'template_score'
                keys are read.

        Returns:
            float: combined relevance heuristic score of precursor
        """
        return self._combined_score(precursor['smiles_split'],
                                    precursor['necessary_reagent'],
                                    precursor['template_score'])

    def reorder_precursors(self, precursors):
        """Reorder a list of precursors by their newly computed combined relevance heuristic score

        The input dictionaries are updated in place with the new keys.

        Args:
            precursors (list of dict)

        Returns:
            list: reordered list of precursor dictionaries with new 'score' and 'rank' keys
        """
        scores = np.array([self.score_precursor(p) for p in precursors])
        # Highest score first.
        order = np.argsort(-scores)
        result = []
        for rank, i in enumerate(order, start=1):
            precursor = precursors[i]
            precursor['score'] = scores[i]
            precursor['rank'] = rank
            result.append(precursor)
        return result

    def get_priority(self, retroPrecursor, **kwargs):
        """Gets priority of given precursor based on heuristic and relevance.

        Args:
            retroPrecursor (RetroPrecursor): Precursor to calculate priority of.
            **kwargs: Unused.

        Returns:
            float: Priority score of precursor.
        """
        return self._combined_score(retroPrecursor.smiles_list,
                                    retroPrecursor.necessary_reagent,
                                    retroPrecursor.template_score)

    def load_model(self):
        """Loads the Pricer used in the heuristic priority scoring."""
        self.pricer = Pricer()
        self.pricer.load()
        self._loaded = True
Esempio n. 9
0
class TreeBuilder:
    def __init__(self,
                 retroTransformer=None,
                 pricer=None,
                 max_branching=20,
                 max_depth=3,
                 expansion_time=240,
                 celery=False,
                 nproc=1,
                 mincount=25,
                 chiral=True,
                 mincount_chiral=10,
                 template_prioritization=gc.relevance,
                 precursor_prioritization=gc.relevanceheuristic,
                 chemhistorian=None):
        """Class for retrosynthetic tree expansion using a depth-first search

        Initialization of an object of the TreeBuilder class sets default values
        for various settings and loads transformers as needed (i.e., based on
        whether Celery is being used or not). Most settings are overridden
        by the get_buyable_paths method anyway.

        Keyword Arguments:
            retroTransformer {None or RetroTransformer} -- RetroTransformer object
                to be used for expansion when *not* using Celery. If none,
                will be initialized using the model_loader.load_Retro_Transformer
                function (default: {None})
            pricer {Pricer} -- Pricer object to be used for checking stop criteria
                (buyability). If none, will be initialized using default settings
                from the global configuration (default: {None})
            max_branching {number} -- Maximum number of precursor suggestions to
                add to the tree at each expansion (default: {20})
            max_depth {number} -- Maximum number of reactions to allow before
                stopping the recursive expansion down one branch (default: {3})
            expansion_time {number} -- Time (in seconds) to allow for expansion
                before searching the generated tree for buyable pathways (default: {240})
            celery {bool} -- Whether or not Celery is being used. If True, then
                the TreeBuilder relies on reservable retrotransformer workers
                initialized separately. If False, then retrotransformer workers
                will be spun up using multiprocessing (default: {False})
            nproc {number} -- Number of retrotransformer processes to fork for
                faster expansion (default: {1})
            mincount {number} -- Minimum number of precedents for an achiral template
                for inclusion in the template library. Only used when retrotransformers
                need to be initialized (default: {25})
            mincount_chiral {number} -- Minimum number of precedents for a chiral template
                for inclusion in the template library. Only used when retrotransformers
                need to be initialized. Chiral templates are necessarily more specific,
                so we generally use a lower threshold than achiral templates (default: {10})
            chiral {bool} -- Whether or not to pay close attention to chirality. When
                False, even achiral templates can lead to accidental inversion of
                chirality in non-reacting parts of the molecule. It is highly
                recommended to keep this as True (default: {True})
            template_prioritization {string} -- Strategy used for template
                prioritization, as a string. There are a limited number of available
                options - consult the global configuration file for info (default: {gc.relevance})
            precursor_prioritization {string} -- Strategy used for precursor
                prioritization, as a string. There are a limited number of available
                options - consult the global configuration file for info
                (default: {gc.relevanceheuristic})
            chemhistorian {None or ChemHistorian} -- ChemHistorian instance. If None,
                a ChemHistorian is created and loaded with
                load_from_file(refs=False, compressed=True) (default: {None})
        """
        # General parameters
        self.celery = celery
        self.max_branching = max_branching
        self.mincount = mincount
        self.mincount_chiral = mincount_chiral
        self.max_depth = max_depth
        self.expansion_time = expansion_time
        self.template_prioritization = template_prioritization
        self.precursor_prioritization = precursor_prioritization
        self.nproc = nproc
        self.chiral = chiral
        self.max_cum_template_prob = 1

        if pricer:
            self.pricer = pricer
        else:
            self.pricer = Pricer()
            self.pricer.load()

        self.chemhistorian = chemhistorian
        if chemhistorian is None:
            from makeit.utilities.historian.chemicals import ChemHistorian
            self.chemhistorian = ChemHistorian()
            self.chemhistorian.load_from_file(refs=False, compressed=True)

        self.reset()

        # When not using Celery, need to ensure retroTransformer initialized
        if not self.celery:
            if retroTransformer:
                self.retroTransformer = retroTransformer
            else:
                self.retroTransformer = model_loader.load_Retro_Transformer(
                    mincount=self.mincount,
                    mincount_chiral=self.mincount_chiral,
                    chiral=self.chiral)

        # Define method to check if all results processed
        if self.celery:

            def waiting_for_results():
                # update
                time.sleep(1)
                return self.pending_results != [] or self.is_ready != []
        else:

            def waiting_for_results():
                waiting = [
                    expansion_queue.empty()
                    for expansion_queue in self.expansion_queues
                ]
                waiting.append(self.results_queue.empty())
                waiting += self.idle

                return (not all(waiting))

        self.waiting_for_results = waiting_for_results

        # Define method to get a processed result.
        if self.celery:

            def get_ready_result():
                # Update which processes are ready
                self.is_ready = [
                    i for (i, res) in enumerate(self.pending_results)
                    if res.ready()
                ]

                for i in self.is_ready:
                    (smiles,
                     precursors) = self.pending_results[i].get(timeout=0.2)
                    self.pending_results[i].forget()
                    _id = self.chem_to_id[smiles]
                    yield (_id, smiles, precursors)
                self.pending_results = [
                    res for (i, res) in enumerate(self.pending_results)
                    if i not in self.is_ready
                ]
        else:

            def get_ready_result():
                from queue import Empty
                while not self.results_queue.empty():
                    # Pass timeout by keyword: a positional 0.2 would be the
                    # ``block`` flag and wait with no timeout at all.
                    try:
                        result = self.results_queue.get(timeout=0.2)
                    except Empty:
                        return
                    yield result

        self.get_ready_result = get_ready_result

        # Define method to start up parallelization
        # note: Celery will reserve an entire preforked pool of workers
        if self.celery:

            def prepare():
                request = None
                try:
                    worker = tb_c_worker if self.chiral else tb_worker
                    request = worker.reserve_worker_pool.delay()
                    self.private_worker_queue = request.get(timeout=10)
                except Exception as e:
                    # request stays None when the reservation was never sent.
                    if request is not None:
                        request.revoke()
                    raise IOError(
                        'Did not find an available pool of workers! Try again later ({})'
                        .format(e)) from e
        else:

            def prepare():
                MyLogger.print_and_log(
                    'Tree builder spinning off {} child processes'.format(
                        self.nproc), treebuilder_loc)
                for i in range(self.nproc):
                    p = Process(target=self.work, args=(i, ))
                    self.workers.append(p)
                    p.start()

        self.prepare = prepare

        # Define method to stop working.
        if self.celery:

            def stop():
                if self.pending_results != []:
                    # Purge the private queue directly (hardcoded for AMQP)
                    # rather than revoking tasks, since a revoke is sent to
                    # all workers regardless of type.
                    import celery.bin.amqp
                    from askcos_site.celery import app
                    amqp = celery.bin.amqp.amqp(app=app)
                    amqp.run('queue.purge', self.private_worker_queue)
                if self.private_worker_queue:
                    worker = tb_c_worker if self.chiral else tb_worker
                    # Wait for the unreserve task to complete.
                    worker.unreserve_worker_pool.apply_async(
                        queue=self.private_worker_queue, retry=True).get()
                self.running = False
        else:

            def stop():
                if not self.running:
                    return
                self.done.value = 1
                MyLogger.print_and_log('Terminating tree building process.',
                                       treebuilder_loc)

                for p in self.workers:
                    if p and p.is_alive():
                        p.terminate()
                MyLogger.print_and_log('All tree building processes done.',
                                       treebuilder_loc)
                self.running = False

        self.stop = stop

        # Define method to expand a single compound
        if self.celery:

            def expand(smiles, chem_id, depth):
                # Chiral expansion or relevance template prioritization is
                # dispatched to tb_c_worker; everything else to tb_worker.
                if self.chiral or self.template_prioritization == gc.relevance:
                    worker = tb_c_worker
                else:
                    worker = tb_worker
                self.pending_results.append(
                    worker.get_top_precursors.apply_async(
                        args=(smiles, self.template_prioritization,
                              self.precursor_prioritization),
                        kwargs={
                            'mincount': self.mincount,
                            'max_branching': self.max_branching,
                            'template_count': self.template_count,
                            'mode': self.precursor_score_mode,
                            'max_cum_prob': self.max_cum_template_prob,
                            'apply_fast_filter': self.apply_fast_filter,
                            'filter_threshold': self.filter_threshold
                        },
                        # Prioritize higher depths: Depth first search.
                        priority=int(depth),
                        queue=self.private_worker_queue,
                    ))
        else:

            def expand(smiles, chem_id, depth):
                self.expansion_queues[depth].put((chem_id, smiles))

        self.expand = expand

        # Define how first target is set.
        if self.celery:

            def set_initial_target(smiles):
                self.expand(smiles, 1, 0)
        else:

            def set_initial_target(smiles):
                self.expansion_queues[-1].put((1, smiles))
                print((self.expansion_queues))
                print('Put something on expansion queue')
                # Block until the first result arrives.
                while self.results_queue.empty():
                    time.sleep(0.25)

        self.set_initial_target = set_initial_target

    def reset(self):
        """Reinitialize all per-request tree state and counters.

        Runs after initialization and again once buyable pathways have been
        extracted. This releases memory and, with Celery, lets the next
        request start from a clean slate instead of inheriting the previous
        tree. With Celery, ordinary Python containers are used. Otherwise the
        containers come from a fresh ``Manager()``, and the multiprocessing
        bookkeeping (done/paused flags, idle list, results queue, workers) is
        reset as well.
        """
        if not self.celery:
            # Shared containers for the multiprocessing backend.
            self.manager = Manager()
            shared = self.manager
            self.tree_dict = shared.dict()
            self.chem_to_id = shared.dict()
            self.buyable_leaves = shared.list()
            self.current_id = shared.Value('i', 2)

            # Bookkeeping used only by the multiprocessing backend.
            self.done = shared.Value('i', 0)
            self.paused = shared.Value('i', 0)
            self.idle = shared.list()  # one idle flag per worker
            self.results_queue = Queue()
            self.workers = []
            self.coordinator = None
            self.running = False
            return

        # Celery backend: plain containers; IDs 1 (target) is reserved, so
        # numbering of new nodes starts at 2.
        self.tree_dict = {}
        self.chem_to_id = {}
        self.buyable_leaves = set()
        self.current_id = 2
        self.is_ready = []
        # Celery-specific state.
        self.pending_results = []
        self.private_worker_queue = None

    def get_children(self, precursors):
        """Reshape the precursors of one expansion for the tree builder.

        Arguments:
            precursors {list of dicts} -- one dictionary per precursor holding
                its identity and auxiliary information (e.g. whether the
                template needs reagents to contribute heavy atoms).

        Returns:
            [list of (dict, list)] -- one 2-tuple per precursor. The dict
                carries 'tforms', 'template' (the first entry of 'tforms'),
                'template_score', 'necessary_reagent', 'num_examples',
                'score' and 'plausibility'. The list is the precursor's
                'smiles_split', i.e. its reactant SMILES strings.
        """
        def _as_child(prec):
            # Build the reaction info in the same key order as before, then
            # pair it with the reactant SMILES.
            transforms = prec['tforms']
            info = {
                'tforms': transforms,
                'template': transforms[0],
                'template_score': prec['template_score'],
                'necessary_reagent': prec['necessary_reagent'],
                'num_examples': prec['num_examples'],
                'score': prec['score'],
                'plausibility': prec['plausibility'],
            }
            return (info, prec['smiles_split'])

        return [_as_child(prec) for prec in precursors]

    def add_children(self, children, smiles, unique_id):
        """Add results of one expansion to the tree dictionary

        Each candidate disconnection becomes a reaction node. The node is
        linked to the parent chemical (``unique_id``) via the parent's
        'prod_of' list, and to its reactant chemical nodes via 'rcts'.

        Reactants not yet in the tree get a new ID, a price lookup
        (``self.pricer``) and a history lookup (``self.chemhistorian``).
        - If they satisfy ``self.is_a_leaf_node``, they are recorded in
          ``self.buyable_leaves``.
        - Otherwise they are handed to ``self.expand``, unless the depth limit
          has been reached.

        Reactants already in the tree are reused and only gain a reference to
        the new reaction in their 'rct_of' list.

        A candidate is skipped when:
        - its reaction SMILES is in ``self.known_bad_reactions``, or
        - any reactant is in ``self.forbidden_molecules``, or
        - any reactant equals the root target (node 1).

        Chemical nodes are placed at parent depth + 1 and reaction nodes at
        parent depth + 0.5. Each ``rxn`` dict in ``children`` is modified in
        place ('rcts', 'prod' and 'depth' keys are added) and then stored as
        the reaction node.

        Arguments:
            children {list of (dict, list)} -- list of candidate disconnections
                for the target SMILES, formatted as returned by get_children()
            smiles {string} -- SMILES string of product (target) molecule
            unique_id {integer >= 1} -- ID of product (target) molecule in the
                tree dictionary
        """

        # Read the parent node once and write it back at the end. Without
        # Celery, tree_dict is a Manager dict (see reset()), whose values come
        # back as copies, so in-place edits would otherwise be lost.
        parent_chem_doc = self.tree_dict[unique_id]  # copy to overwrite later
        parent_chem_prod_of = parent_chem_doc['prod_of']
        # Assign unique number
        for (rxn, mols) in children:

            # Add option to leave out blacklisted reactions.
            # Key format: sorted reactant SMILES joined by '.', then '>>' and
            # the product SMILES.
            rxn_smiles = '.'.join(sorted(mols)) + '>>' + smiles
            if rxn_smiles in self.known_bad_reactions:
                continue

            # What should be excluded?
            skip_this = False
            for mol in mols:
                # Exclude banned molecules too
                if mol in self.forbidden_molecules:
                    skip_this = True
                # Exclude reactions where the reactant is the target
                if mol == self.tree_dict[1]['smiles']:
                    skip_this = True
            if skip_this:
                continue

            # depending on whether current_id was given as 'Manager.Value' type
            # or 'Integer':
            if self.celery:
                rxn_id = self.current_id
                self.current_id += 1
            else:
                rxn_id = self.current_id.value
                # this is only okay because there is/should be only ONE
                # treebuilder
                self.current_id.value += 1
            # For the parent molecule, record child reactions
            parent_chem_prod_of.append(rxn_id)

            # For the reaction, keep track of children IDs
            chem_ids = []
            for mol in mols:

                # New chemical?
                if mol not in self.chem_to_id:

                    # Same ID allocation as above. Here the two current_id
                    # representations are told apart by duck typing (a plain
                    # int has no .value) instead of checking self.celery.
                    try:
                        chem_id = self.current_id.value
                        # this is only okay because there is/should be only ONE
                        # treebuilder
                        self.current_id.value += 1
                    except AttributeError:
                        chem_id = self.current_id
                        self.current_id += 1

                    # Check if buyable
                    ppg = self.pricer.lookup_smiles(mol, alreadyCanonical=True)
                    hist = self.chemhistorian.lookup_smiles(
                        mol, alreadyCanonical=True)

                    self.tree_dict[chem_id] = {
                        'smiles': mol,
                        'prod_of': [],
                        'rct_of': [rxn_id],
                        'depth': parent_chem_doc['depth'] + 1,
                        'ppg': ppg,
                        'as_reactant': hist['as_reactant'],
                        'as_product': hist['as_product'],
                    }
                    self.chem_to_id[mol] = chem_id

                    # Check stop criterion
                    if self.is_a_leaf_node(mol, ppg, hist):
                        # print('{} is a leaf!'.format(mol))
                        if self.celery:
                            self.buyable_leaves.add(chem_id)
                        else:
                            self.buyable_leaves.append(chem_id)

                    else:
                        # Add to queue to get expanded
                        # The new chemical would sit at depth >= max_depth, so
                        # it is not expanded any further.
                        if parent_chem_doc['depth'] >= self.max_depth - 1:
                            if gc.DEBUG:
                                MyLogger.print_and_log(
                                    'Reached maximum depth, so will not expand around {}'
                                    .format(self.tree_dict[chem_id]),
                                    treebuilder_loc)
                        else:
                            # The parent's depth is what expand() receives
                            # (it selects the expansion queue).
                            self.expand(mol, chem_id, parent_chem_doc['depth'])

                else:
                    chem_id = self.chem_to_id[mol]

                    # Overwrite this chemical node to record it is a reactant
                    # of this rxn
                    # (read-modify-write so the update also sticks when
                    # tree_dict is a Manager dict)
                    chem_doc = self.tree_dict[chem_id]
                    chem_doc['rct_of'] += [rxn_id]
                    self.tree_dict[chem_id] = chem_doc

                # Save ID
                chem_ids.append(chem_id)

            # Record by overwriting the whole dict value
            rxn['rcts'] = chem_ids
            rxn['prod'] = unique_id
            rxn['depth'] = parent_chem_doc['depth'] + 0.5
            self.tree_dict[rxn_id] = rxn

        # Overwrite dictionary entry for the parent
        parent_chem_doc['prod_of'] = parent_chem_prod_of
        self.tree_dict[unique_id] = parent_chem_doc

    def work(self, i):
        """Work function for retroTransformer processes

        Only used when Celery is false and multiprocessing is used instead.

        Loops until ``self.done.value`` is set. While ``self.paused.value`` is
        set, it sleeps for one second at a time. Otherwise it sweeps the
        expansion queues from the highest index to the lowest, taking at most
        one (id, SMILES) item from each queue per sweep. Each item is expanded
        with ``self.retroTransformer.get_outcomes``, and
        ``(id, smiles, top precursors)`` is put on ``self.results_queue``.

        An error while expanding one molecule is reported on stdout and the
        worker keeps running.

        Arguments:
            i {integer >= 0} -- index assigned to the worker, used to assign
                idle status to the shared list self.idle[i]
        """
        while True:
            # If done, stop
            if self.done.value:
                MyLogger.print_and_log(
                    'Worker {} saw done signal, terminating'.format(i),
                    treebuilder_loc)
                break
            # If paused, wait and check again
            if self.paused.value:
                time.sleep(1)
                continue
            # Queues are indexed by depth, so visit the deepest one first.
            for j in reversed(range(len(self.expansion_queues))):
                # Reset each time so an error raised by get() itself never
                # reports a SMILES left over from a previous iteration.
                smiles = None
                try:
                    (_id, smiles) = self.expansion_queues[j].get(
                        timeout=0.1)  # short timeout
                    self.idle[i] = False
                    result = self.retroTransformer.get_outcomes(
                        smiles,
                        self.mincount, (self.precursor_prioritization,
                                        self.template_prioritization),
                        template_count=self.template_count,
                        mode=self.precursor_score_mode,
                        max_cum_prob=self.max_cum_template_prob,
                        apply_fast_filter=self.apply_fast_filter,
                        filter_threshold=self.filter_threshold)

                    precursors = result.return_top(n=self.max_branching)

                    self.results_queue.put((_id, smiles, precursors))

                except VanillaQueue.Empty:
                    # Nothing waiting at this depth; try the next queue.
                    pass
                except Exception as e:
                    # Best effort: report and keep the worker alive. Name the
                    # worker and molecule, and end the line, so messages from
                    # concurrent workers do not run together.
                    sys.stdout.write('Worker {} failed to expand {}: {}\n'.format(
                        i, smiles, e))
                    sys.stdout.flush()
            time.sleep(0.01)
            self.idle[i] = True

    def coordinate(self):
        """Run the expansion up to a specified self.expansion_time (seconds)

        Repeatedly takes finished expansions from ``self.get_ready_result()``
        and merges them into the tree with ``get_children``/``add_children``.
        It stops once ``self.expansion_time`` seconds have elapsed or
        ``self.waiting_for_results()`` returns a falsy value.

        A progress message is logged about every 10 seconds. An exception
        while collecting or merging results is printed and the loop
        continues.
        """
        start_time = time.time()
        elapsed_time = time.time() - start_time
        # Next 10-second multiple at which progress is logged. Floor division
        # and ">=" keep logging alive even when the loop misses an entire
        # second. The old "int(t) / 10 == next" is true division in Python 3
        # and only matched during the exact second 10*k.
        next_log = 1
        while (elapsed_time <
               self.expansion_time) and self.waiting_for_results():
            if int(elapsed_time) // 10 >= next_log:
                next_log = int(elapsed_time) // 10 + 1
                MyLogger.print_and_log(
                    'Worked for {}/{} s'.format(
                        int(elapsed_time * 10) / 10.0, self.expansion_time),
                    treebuilder_loc)
            try:
                for (_id, smiles, precursors) in self.get_ready_result():
                    children = self.get_children(precursors)
                    self.add_children(children, smiles, _id)
            except Exception as e:
                print(('##ERROR#: {}'.format(e)))
            elapsed_time = time.time() - start_time

    def build_tree(self, target):
        """Seed the tree with ``target`` and expand it.

        The target becomes chemical node 1, with its price and chemical
        history looked up. If it already satisfies ``self.is_a_leaf_node``,
        it is recorded as a buyable leaf. Then the workers are prepared, the
        initial target is submitted and the coordinator runs.
        ``self.stop()`` is always called afterwards, even when an exception
        escapes.

        Arguments:
            target {string} -- SMILES of target molecule
        """
        self.running = True
        if not self.celery:
            from makeit.utilities.with_dummy import with_dummy as allow_join_result
        else:
            from celery.result import allow_join_result
        with allow_join_result():
            try:
                history = self.chemhistorian.lookup_smiles(target)
                root_ppg = self.pricer.lookup_smiles(target)
                self.tree_dict[1] = {
                    'smiles': target,
                    'prod_of': [],
                    'rct_of': [],
                    'depth': 0,
                    'ppg': root_ppg,
                    'as_reactant': history['as_reactant'],
                    'as_product': history['as_product'],
                }

                if self.is_a_leaf_node(target, root_ppg, history):
                    # Celery keeps leaves in a set, multiprocessing in a
                    # managed list.
                    record_leaf = (self.buyable_leaves.add if self.celery
                                   else self.buyable_leaves.append)
                    record_leaf(1)

                self.chem_to_id[target] = 1

                self.prepare()
                self.set_initial_target(target)
                self.coordinate()
            finally:
                # Always shut down gracefully, even on error.
                self.stop()

    def tree_status(self):
        """Count the chemical and reaction nodes in the expanded tree.

        Chemical nodes sit at integer depths and reaction nodes at
        half-integer depths, so the fractional part of a node's 'depth'
        decides which counter it adds to.

        Returns:
            num_chemicals {int} -- number of chemical nodes in the tree
            num_reactions {int} -- number of reaction nodes in the tree
            at_depth {dict} -- dictionary containing counts at each integer
                depth (chemicals) and half-integer depth (reactions)
        """
        per_depth = {}
        chemical_count = 0
        reaction_count = 0
        for node_id in list(self.tree_dict.keys()):
            node_depth = self.tree_dict[node_id]['depth']
            if node_depth % 1 != 0:
                reaction_count += 1
            else:
                chemical_count += 1
            per_depth[node_depth] = per_depth.get(node_depth, 0) + 1
        return (chemical_count, reaction_count, per_depth)

    def get_buyable_paths(self,
                          target,
                          max_depth=3,
                          max_branching=25,
                          expansion_time=240,
                          template_prioritization=gc.relevance,
                          precursor_prioritization=gc.heuristic,
                          nproc=1,
                          mincount=25,
                          chiral=True,
                          mincount_chiral=10,
                          max_trees=25,
                          max_ppg=1e10,
                          known_bad_reactions=None,
                          forbidden_molecules=None,
                          template_count=100,
                          precursor_score_mode=gc.max,
                          max_cum_template_prob=1,
                          max_natom_dict=None,
                          min_chemical_history_dict=None,
                          apply_fast_filter=False,
                          filter_threshold=0.5):
        """Get viable synthesis trees using an iterative deepening depth-first search

        First the tree rooted at ``target`` is expanded for up to
        ``expansion_time`` seconds (see ``build_tree``). The expanded tree is
        then searched with iterative deepening for synthesis routes whose
        leaves all satisfy the stop criterion. That criterion is built from
        ``max_ppg``, ``max_natom_dict`` and ``min_chemical_history_dict``.

        Arguments:
            target {string} -- SMILES of the target molecule

        Keyword Arguments:
            max_depth {number} -- Maximum number of reactions to allow before
                stopping the recursive expansion down one branch (default: {3})
            max_branching {number} -- Maximum number of precursor suggestions to
                add to the tree at each expansion (default: {25})
            expansion_time {number} -- Time (in seconds) to allow for expansion
                before searching the generated tree for buyable pathways (default: {240})
            nproc {number} -- Number of retrotransformer processes to fork for
                faster expansion (default: {1})
            mincount {number} -- Minimum number of precedents for an achiral template
                for inclusion in the template library. Only used when retrotransformers
                need to be initialized (default: {25})
            mincount_chiral {number} -- Minimum number of precedents for a chiral template
                for inclusion in the template library. Only used when retrotransformers
                need to be initialized. Chiral templates are necessarily more specific,
                so we generally use a lower threshold than achiral templates (default: {10})
            chiral {bool} -- Whether or not to pay close attention to chirality. When
                False, even achiral templates can lead to accidental inversion of
                chirality in non-reacting parts of the molecule. It is highly
                recommended to keep this as True. Relevance-based template
                prioritization always forces True (default: {True})
            template_prioritization {string} -- Strategy used for template
                prioritization, as a string. There are a limited number of available
                options - consult the global configuration file for info (default: {gc.relevance})
            precursor_prioritization {string} -- Strategy used for precursor
                prioritization, as a string. There are a limited number of available
                options - consult the global configuration file for info (default: {gc.heuristic})
            max_trees {number} -- Maximum number of buyable trees to return. Does
                not affect expansion time (default: {25})
            max_ppg {number} -- Maximum price ($/g) for a chemical to be considered
                buyable, and thus potentially usable as a leaf node (default: {1e10})
            known_bad_reactions {list or None} -- Reactions to forbid during expansion,
                represented as list of reaction SMILES strings. Each reaction SMILES
                must be canonicalized, have atom mapping removed, and have its
                reactant fragments be sorted. Forbidden reactions are checked
                when processing children returned by the RetroTransformer.
                None means no forbidden reactions (default: {None})
            forbidden_molecules {list or None} -- Molecules to forbid during expansion,
                represented as a list of SMILES strings. Each string must be
                canonicalized without atom mapping. Forbidden molecules will not
                be allowed as intermediates or leaf nodes. None means no
                forbidden molecules (default: {None})
            template_count {number} -- Maximum number of templates to apply at
                each expansion (default: {100})
            precursor_score_mode {string} -- Mode to use for precursor scoring
                when using the SCScore prioritizer and multiple reactant
                fragments must be scored together (default: {gc.max})
            max_cum_template_prob {number} -- Maximum cumulative template
                probability (i.e., relevance score), used as part of the
                relevance template_prioritizer (default: {1})
            max_natom_dict {dict or None} -- Dictionary defining a potential chemical
                property stopping criterion based on the number of atoms of
                C, N, O, and H in a molecule. The 'logic' keyword of the dict
                refers to how that maximum number of atom information is combined
                with the requirement that chemicals be cheaper than max_ppg.
                None means defaultdict(lambda: 1e9, {'logic': None}), i.e. no
                atom-count criterion (default: {None})
            min_chemical_history_dict {dict or None} -- Dictionary defining a potential
                chemical stopping criterion based on the number of times a
                molecule has been seen previously. Always uses logical OR.
                None means {'as_reactant': 1e9, 'as_product': 1e9,
                'logic': None}, i.e. no history criterion (default: {None})
            apply_fast_filter {bool} -- Forwarded to every expansion call
                (default: {False})
            filter_threshold {number} -- Forwarded to every expansion call
                together with apply_fast_filter (default: {0.5})

        Returns:
            tree_status -- result of tree_status()
            trees -- list of dictionaries, where each dictionary defines a
                synthetic route

        Raises:
            ValueError -- if a reaction node in the tree has no reactants
        """
        # Build fresh containers per call instead of sharing mutable defaults
        # across calls.
        if known_bad_reactions is None:
            known_bad_reactions = []
        if forbidden_molecules is None:
            forbidden_molecules = []
        if max_natom_dict is None:
            max_natom_dict = defaultdict(lambda: 1e9, {'logic': None})
        if min_chemical_history_dict is None:
            min_chemical_history_dict = {
                'as_reactant': 1e9,
                'as_product': 1e9,
                'logic': None
            }

        self.mincount = mincount
        self.mincount_chiral = mincount_chiral
        self.max_depth = max_depth
        self.max_branching = max_branching
        self.expansion_time = expansion_time
        self.template_prioritization = template_prioritization
        self.precursor_prioritization = precursor_prioritization
        self.precursor_score_mode = precursor_score_mode
        self.nproc = nproc
        self.template_count = template_count
        self.max_cum_template_prob = max_cum_template_prob
        self.max_ppg = max_ppg
        self.apply_fast_filter = apply_fast_filter
        self.filter_threshold = filter_threshold

        MyLogger.print_and_log(
            'Starting to expand {} using max_natom_dict {}, min_history {}'.
            format(target, max_natom_dict, min_chemical_history_dict),
            treebuilder_loc,
            level=1)

        # A history-based stop criterion needs a chemical historian.
        if min_chemical_history_dict['logic'] not in [None, 'none'] and \
                self.chemhistorian is None:
            from makeit.utilities.historian.chemicals import ChemHistorian
            self.chemhistorian = ChemHistorian()
            self.chemhistorian.load_from_file(refs=False, compressed=True)
            MyLogger.print_and_log('Loaded compressed chemhistorian from file',
                                   treebuilder_loc,
                                   level=1)

        # Define stop criterion
        def is_buyable(ppg):
            return ppg and (ppg <= self.max_ppg)

        def is_small_enough(smiles):
            # Get structural properties
            natom_dict = defaultdict(lambda: 0)
            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                return False
            for a in mol.GetAtoms():
                natom_dict[a.GetSymbol()] += 1
            natom_dict['H'] = sum(a.GetTotalNumHs() for a in mol.GetAtoms())
            max_natom_satisfied = all(natom_dict[k] <= v
                                      for (k,
                                           v) in list(max_natom_dict.items())
                                      if k != 'logic')
            return max_natom_satisfied

        def is_popular_enough(hist):
            return hist['as_reactant'] >= min_chemical_history_dict['as_reactant'] or \
                    hist['as_product'] >= min_chemical_history_dict['as_product']

        if min_chemical_history_dict['logic'] in [None, 'none']:
            if max_natom_dict['logic'] in [None, 'none']:

                def is_a_leaf_node(smiles, ppg, hist):
                    return is_buyable(ppg)
            elif max_natom_dict['logic'] == 'or':

                def is_a_leaf_node(smiles, ppg, hist):
                    return is_buyable(ppg) or is_small_enough(smiles)
            else:

                def is_a_leaf_node(smiles, ppg, hist):
                    return is_buyable(ppg) and is_small_enough(smiles)
        else:
            if max_natom_dict['logic'] in [None, 'none']:

                def is_a_leaf_node(smiles, ppg, hist):
                    return is_buyable(ppg) or is_popular_enough(hist)
            elif max_natom_dict['logic'] == 'or':

                def is_a_leaf_node(smiles, ppg, hist):
                    return is_buyable(ppg) or is_popular_enough(
                        hist) or is_small_enough(smiles)
            else:

                def is_a_leaf_node(smiles, ppg, hist):
                    return is_popular_enough(hist) or (is_buyable(ppg) and
                                                       is_small_enough(smiles))

        self.is_a_leaf_node = is_a_leaf_node

        # Override: if relevance method is used, chiral database must be used!
        # Assign unconditionally, so a stale True left by an earlier
        # configuration cannot override an explicit chiral=False.
        self.chiral = bool(chiral or template_prioritization == gc.relevance)
        if template_prioritization == gc.relevance and not self.celery:
            if not (self.retroTransformer.mincount == 25
                    and self.retroTransformer.mincount_chiral == 10
                    and self.retroTransformer.chiral):
                MyLogger.print_and_log(
                    'When using relevance based template prioritization, chiral template database '
                    +
                    'must be used with mincount = 25 and mincount_chiral = 10. Exiting...',
                    treebuilder_loc,
                    level=3)

        self.known_bad_reactions = known_bad_reactions
        self.forbidden_molecules = forbidden_molecules
        self.reset()

        # Initialize multiprocessing queues if necessary
        if not self.celery:
            for i in range(nproc):
                self.idle.append(True)
            if self.max_depth != 1:
                self.expansion_queues = [
                    Queue() for i in range(self.max_depth - 1)
                ]
            else:
                self.expansion_queues = [Queue()]

        # Generate trees
        self.build_tree(target)

        def IDDFS():
            """Perform an iterative deepening depth-first search to find buyable
            pathways.

            Yields:
                nested dictionaries defining synthesis trees
            """
            for depth in range(self.max_depth + 1):
                for path in DLS_chem(1, depth, headNode=True):
                    yield chem_dict(1, children=path, **self.tree_dict[1])

        def DLS_chem(chem_id, depth, headNode=False):
            """Expand at a fixed depth for the current node chem_id.

            Arguments:
                chem_id {int >= 1} -- unique ID of the current chemical
                depth {int >= 0} -- current depth of the search

            Keyword Arguments:
                headNode {bool} -- Whether this is the first (head) node, in which
                    case it must be expanded even if it is buyable itself (default: {False})

            Yields:
                list -- paths connecting to buyable molecules that are children
                    of the current chemical
            """
            if depth <= 0:
                # Not allowing deeper - is this buyable?
                if chem_id in self.buyable_leaves:
                    yield [
                    ]  # viable node, calling function doesn't need children
            else:
                # Do we need to go deeper?
                if chem_id in self.buyable_leaves and not headNode:
                    yield []  # Nope, this is a viable node
                else:
                    # Try going deeper via DLS_rxn function
                    for rxn_id in self.tree_dict[chem_id]['prod_of']:
                        rxn_smiles = '.'.join(
                            sorted([
                                self.tree_dict[x]['smiles']
                                for x in self.tree_dict[rxn_id]['rcts']
                            ])) + '>>' + self.tree_dict[chem_id]['smiles']
                        for path in DLS_rxn(rxn_id, depth):
                            yield [
                                rxn_dict(rxn_id,
                                         rxn_smiles,
                                         children=path,
                                         **self.tree_dict[rxn_id])
                            ]

        def DLS_rxn(rxn_id, depth):
            """Return children paths starting from a specific rxn_id

            Yields every combination of the reactants' paths (found one level
            deeper). The first reactant varies slowest, matching explicitly
            nested loops. Any number of reactants is supported. Combinations
            are produced lazily, so the search can stop as soon as max_trees
            is reached.

            Arguments:
                rxn_id {int >= 2} -- unique ID of this reaction in the tree_dict
                depth {int >= 0} -- current depth of the search

            Yields:
                list -- paths connecting to buyable molecules that are children
                    of the current reaction (one chem_dict per reactant)

            Raises:
                ValueError -- if the reaction has no reactants
            """
            rct_ids = list(self.tree_dict[rxn_id]['rcts'])
            if not rct_ids:
                raise ValueError(
                    'Reaction {} has no reactants'.format(rxn_id))
            for combo in reactant_path_combinations(rct_ids, depth - 1):
                # Build fresh chem_dicts per yielded path, so different trees
                # never share sub-dictionaries.
                yield [
                    chem_dict(chem_id,
                              children=path,
                              **self.tree_dict[chem_id])
                    for (chem_id, path) in combo
                ]

        def reactant_path_combinations(chem_ids, depth):
            """Lazily yield lists of (chem_id, path) pairs, one per chemical.

            Equivalent to nesting ``for path in DLS_chem(chem_id, depth)``
            loops over ``chem_ids``, with the first ID outermost. Inner
            searches are re-run for every outer path, as with explicit loops.
            """
            if not chem_ids:
                yield []
                return
            first_id = chem_ids[0]
            for path in DLS_chem(first_id, depth):
                for rest in reactant_path_combinations(chem_ids[1:], depth):
                    yield [(first_id, path)] + rest

        # Generate paths and ensure unique
        import hashlib
        import json
        done_trees = set()
        trees = []
        counter = 0
        for tree in IDDFS():
            # Deeper iterations rediscover shallower trees; dedupe by content.
            hashkey = hashlib.sha1(
                json.dumps(tree, sort_keys=True).encode('utf-8')).hexdigest()

            if hashkey in done_trees:
                continue

            done_trees.add(hashkey)
            trees.append(tree)
            counter += 1

            if counter == max_trees:
                MyLogger.print_and_log(
                    'Generated {} trees (max_trees met), stopped looking for more...'
                    .format(max_trees), treebuilder_loc)
                break

        tree_status = self.tree_status()
        if self.celery:
            self.reset()  # free up memory, don't hold tree
        return (tree_status, trees)
Esempio n. 10
0
class MCTS:
    def __init__(self,
                 retroTransformer=None,
                 pricer=None,
                 max_branching=50,
                 total_applied_templates=1000,
                 max_depth=10,
                 celery=False,
                 nproc=8,
                 mincount=25,
                 chiral=True,
                 template_prioritization=gc.relevance,
                 precursor_prioritization=gc.heuristic,
                 max_ppg=100,
                 mincount_chiral=10):
        """Set up the search and bind the backend-specific helper methods.

        Loads the pricer and the chiral retro templates, calls ``reset()``,
        then defines ``expand``, ``prepare``, ``waiting_for_results``,
        ``get_ready_result``, ``get_pathway_result``, ``set_initial_target``
        and ``stop`` as instance attributes, using either the Celery or the
        local multiprocessing implementation depending on ``celery``.

        Args:
            retroTransformer: Unused; templates are always loaded from the
                MongoDB collection named in ``gc.RETRO_TRANSFORMS_CHIRAL``.
                Kept for interface compatibility.
            pricer (Pricer or None): Pricer to use. A new ``Pricer()`` is
                created when falsy; either way it is loaded with ``max_ppg``.
            max_branching (int): Forwarded to Celery workers.
            total_applied_templates (int): Per-pathway expansion budget
                (stored as the third entry of each ``pathway_status`` list).
            max_depth (int): Maximum pathway depth.
            celery (bool): Use Celery workers instead of local processes.
            nproc (int): Number of local worker processes.
            mincount (int): Minimum template count, forwarded to the
                RetroTransformer and to ``get_outcomes``.
            chiral (bool): Whether chiral templates are used.
            template_prioritization: Template prioritization method key.
            precursor_prioritization: Precursor prioritization method key.
            max_ppg: Forwarded to ``Pricer.load``.
            mincount_chiral (int): Forwarded to the RetroTransformer.
        """

        self.celery = celery
        self.mincount = mincount
        self.mincount_chiral = mincount_chiral
        self.max_depth = max_depth
        self.max_branching = max_branching
        self.total_applied_templates = total_applied_templates

        self.template_prioritization = template_prioritization
        self.precursor_prioritization = precursor_prioritization
        self.nproc = nproc
        self.chiral = chiral
        self.max_cum_template_prob = 1

        # Pricer: reuse the caller's instance if given, but always (re)load
        # it with this search's max_ppg.
        if pricer:
            self.pricer = pricer
        else:
            self.pricer = Pricer()
        self.pricer.load(max_ppg=max_ppg)

        self.reset()

        # Templates are read straight from MongoDB instead of going through
        # makeit.utilities.io.model_loader, which imports matplotlib (not
        # available on the cluster). This is why ``retroTransformer`` is
        # ignored.
        db_client = MongoClient(gc.MONGO['path'],
                                gc.MONGO['id'],
                                connect=gc.MONGO['connect'])
        TEMPLATE_DB = db_client[gc.RETRO_TRANSFORMS_CHIRAL['database']][
            gc.RETRO_TRANSFORMS_CHIRAL['collection']]
        self.retroTransformer = RetroTransformer(
            mincount=self.mincount,
            mincount_chiral=self.mincount_chiral,
            TEMPLATE_DB=TEMPLATE_DB)
        self.retroTransformer.chiral = self.chiral

        # Pickled template cache; the home directory is overridden on the
        # 'rigel' cluster.
        home = os.path.expanduser('~')
        if home.split("/")[1] == "rigel":
            home = "/rigel/cheme/users/jss2278/chemical_networks"
        transformer_filepath = home + "/Make-It/makeit/data/"
        if os.path.isfile(transformer_filepath + "chiral_templates.pickle"):
            self.retroTransformer.load_from_file(
                True,
                "chiral_templates.pickle",
                chiral=self.chiral,
                rxns=True,
                file_path=transformer_filepath)
        else:
            self.retroTransformer.dump_to_file(True,
                                               "chiral_templates.pickle",
                                               chiral=self.chiral,
                                               file_path=transformer_filepath)

        if self.celery:

            def expand(smiles, chem_id, queue_depth, branching):
                # NOTE(review): expand_products calls
                # self.expand(_id, chem_smi, depth, branching), so in Celery
                # mode ``smiles`` receives the pathway id. The Celery path
                # looks unfinished; confirm before relying on it.
                if self.chiral or self.template_prioritization == gc.relevance:
                    worker = tb_c_worker
                else:
                    worker = tb_worker
                self.pending_results.append(
                    worker.get_top_precursors.apply_async(
                        args=(smiles, self.template_prioritization,
                              self.precursor_prioritization),
                        kwargs={
                            'mincount': self.mincount,
                            'max_branching': self.max_branching,
                            'template_count': self.template_count,
                            'mode': self.precursor_score_mode,
                            'max_cum_prob': self.max_cum_template_prob
                        },
                        # Deeper nodes get higher priority (depth-first).
                        # Uses this function's own depth argument; a bare
                        # ``depth`` is not bound in this scope.
                        priority=int(queue_depth),
                        queue=self.private_worker_queue,
                    ))

            def prepare():
                # Reserve a private queue from the matching worker pool.
                pool = tb_c_worker if self.chiral else tb_worker
                self.private_worker_queue = pool.reserve_worker_pool.delay(
                ).get(timeout=5)

            def waiting_for_results():
                time.sleep(1)
                return self.pending_results != [] or self.is_ready != []

            def get_ready_result():
                # Collect every finished task, then drop it from the pending
                # list.
                self.is_ready = [
                    i for (i, res) in enumerate(self.pending_results)
                    if res.ready()
                ]
                for i in self.is_ready:
                    (smiles,
                     precursors) = self.pending_results[i].get(timeout=0.25)
                    self.pending_results[i].forget()
                    _id = self.chem_to_id[smiles]
                    yield (_id, smiles, precursors)
                self.pending_results = [
                    res for (i, res) in enumerate(self.pending_results)
                    if i not in self.is_ready
                ]

            # Celery mode polls the same pending results for both streams.
            get_pathway_result = get_ready_result

            def set_initial_target(_id, smiles):
                # NOTE(review): expand() takes four arguments, so this
                # two-argument call raises TypeError as written.
                self.expand(smiles, 1)

            def stop():
                if self.pending_results != []:
                    # Purge tasks that were queued but never consumed.
                    import celery.bin.amqp
                    from askcos_site.celery import app
                    amqp = celery.bin.amqp.amqp(app=app)
                    amqp.run('queue.purge', self.private_worker_queue)
                pool = tb_c_worker if self.chiral else tb_worker
                pool.unreserve_worker_pool.apply_async(
                    queue=self.private_worker_queue, retry=True).get()
                self.running = False

        else:

            def expand(_id, chem_smi, depth, branching):
                # Hand the chemical to the worker that owns pathway slot _id.
                self.expansion_queues[_id].put(
                    (_id, chem_smi, depth, branching))

            def prepare():
                print('Tree builder spinning off {} child processes'.format(
                    self.nproc))
                for i in range(self.nproc):
                    p = Process(target=self.work, args=(i, ))
                    self.workers.append(p)
                    p.start()

            def waiting_for_results():
                # True while any queue still holds work or any worker is busy.
                quiet = [
                    expansion_queue.empty()
                    for expansion_queue in self.expansion_queues
                ]
                for results_queue in self.results_queues:
                    quiet.append(results_queue.empty())
                quiet += self.idle
                return not all(quiet)

            def get_ready_result():
                for results_queue in self.results_queues:
                    while not results_queue.empty():
                        yield results_queue.get(timeout=0.25)

            def get_pathway_result():
                # Pathway slots freed by the coordinator, ready for a new
                # target.
                while not self.pathways_queue.empty():
                    yield self.pathways_queue.get(timeout=0.25)

            def set_initial_target(_id, leaves):
                for leaf in leaves:
                    self.active_chemicals[_id].append(leaf)
                    self.expand_products(_id, [leaf], self.expansion_branching)

            def stop():
                if not self.running:
                    return
                self.done.value = 1  # makes every work() loop exit
                for p in self.workers:
                    if p and p.is_alive():
                        p.terminate()
                self.running = False

        self.expand = expand
        self.prepare = prepare
        self.waiting_for_results = waiting_for_results
        self.get_ready_result = get_ready_result
        self.get_pathway_result = get_pathway_result
        self.set_initial_target = set_initial_target
        self.stop = stop

    def get_price(self, chem_smi):
        """Return 0.0 when the pricer reports a truthy price for ``chem_smi``.

        The actual price is discarded, so every buyable chemical is treated
        as free. Anything the pricer reports as falsy yields ``None``.
        ``chem_smi`` is passed to the pricer as already canonical.
        """
        found = self.pricer.lookup_smiles(chem_smi, alreadyCanonical=True)
        return 0.0 if found else None

    def update_tree(self, _id):
        """Score the finished pathway in slot ``_id`` and record it.

        Unexpandable nodes (empty ``retro_results``) are priced at
        ``self.max_penalty`` and nodes at ``self.max_depth`` at
        ``self.depth_penalty`` so that ``MinCost`` counts them. Costs are then
        reset and recomputed from the target node ``(target, 0)``. The
        pathway counts as successful when every node without incoming
        reactions has a purchase price of exactly 0.0, presumably the value
        that ``get_price`` yields for buyable chemicals. Per-depth
        chemical/reaction counts and the target cost are handed to
        ``save_pathway``.

        Always increments ``self.pathway_count``; increments
        ``self.successful_pathway_count`` for buyable pathways. Errors are
        printed with a traceback instead of propagating, so one bad pathway
        does not stop the coordinator loop.

        Args:
            _id (int): Pathway slot index into ``self.pathways``.
        """
        try:
            self.pathway_count += 1
            pathway = self.pathways[_id]
            chemicals = pathway['chemical_nodes']
            reactions = pathway['reaction_nodes']
            target_smiles = pathway['target']
            smiles_id = pathway['smiles_id']

            # Fold penalties into the 'purchase price' so MinCost counts them.
            for key, chem in chemicals.items():
                if chem.retro_results == []:
                    chem.price(self.max_penalty)
                elif key[1] == self.max_depth:
                    chem.price(self.depth_penalty)

            # Update costs / successes.
            Reset(chemicals, reactions)
            MinCost((target_smiles, 0), self.max_depth, chemicals, reactions)
            target_cost = chemicals[(target_smiles, 0)].cost

            # Successful only if every leaf (no incoming reaction) is free.
            buyable = all(chem.purchase_price == 0.0
                          for chem in chemicals.values()
                          if len(chem.incoming_reactions) == 0)
            if buyable:
                self.successful_pathway_count += 1

            # Per-depth counts of precursor chemicals and of reactions.
            c_branching = {k: 0 for k in range(1, self.max_depth + 1)}
            r_branching = {k: 0 for k in range(1, self.max_depth + 1)}
            for (reac_smiles, reac_depth) in reactions:
                c_branching[reac_depth] += len(reac_smiles.split("."))
                r_branching[reac_depth] += 1

            if target_cost == float('inf'):
                # Diagnostic dump when no finite-cost route was found.
                for key, chem in chemicals.items():
                    print("{} {} {} {}".format(key, chem.purchase_price,
                                               chem.cost, chem.retro_results))

            self.save_pathway(pathway, target_cost, smiles_id,
                              [c_branching, r_branching], buyable)
        except Exception:
            # Best effort: report and keep the coordinator running.
            print("Error in update_tree: " + traceback.format_exc())

    def save_pathway(self, pathway, target_cost, target_id, branching,
                     buyable):
        """Append one summary line for a finished pathway to ``self.fileName``.

        The line is space-separated: ``target_id``, ``target_cost``,
        ``int(buyable)``, then the chemical counts for depths
        1..``self.max_depth``, then the reaction counts for the same depths.

        Args:
            pathway: Unused; kept for interface compatibility.
            target_cost: Cost of the target node.
            target_id: Identifier of the target chemical.
            branching: Pair ``(chemical_counts, reaction_counts)``, each
                indexable by depth 1..``self.max_depth``.
            buyable (bool): Whether every leaf was buyable.
        """
        chem_counts, rxn_counts = branching
        depths = range(1, self.max_depth + 1)
        fields = [str(chem_counts[d]) for d in depths]
        fields.extend(str(rxn_counts[d]) for d in depths)
        line = "{} {} {} {}\n".format(target_id, target_cost, int(buyable),
                                      " ".join(fields))

        with open(self.fileName, "a+") as handle:
            handle.write(line)

    def coordinate(self):
        """Drive the rollout loop until time runs out or all targets are done.

        Each pass:

        1. drains worker results and grows the owning pathways;
        2. retires pathways that can make no further progress (scored via
           ``update_tree``; the slot id is put on ``self.pathways_queue``);
        3. starts the next target from ``self.smiles_generator`` in every
           freed slot.

        The loop ends after ``self.expansion_time`` seconds, when
        ``waiting_for_results`` reports nothing outstanding, or, once the
        target generator is exhausted, when no slot has active chemicals.
        ``self.stop()`` is called when the loop ends. Errors inside a pass
        are printed and the loop continues. Any other unexpected error is
        printed and the process exits with status 1.
        """
        try:
            start_time = time.time()
            next_report = 1  # next multiple of 10 s at which to report
            finished = False
            while True:
                # Refresh the clock on every pass so the time limit and the
                # progress report never go stale.
                elapsed_time = time.time() - start_time
                if (elapsed_time >= self.expansion_time
                        or not self.waiting_for_results()):
                    break

                if int(elapsed_time) // 10 >= next_report:
                    next_report = int(elapsed_time) // 10 + 1
                    print("Worked for {}/{} s".format(
                        int(elapsed_time * 10) / 10.0, self.expansion_time))
                    print("... attempts {}\n... pathways {}".format(
                        self.pathway_count, self.successful_pathway_count))

                try:
                    self._absorb_worker_results()
                    self._retire_stalled_pathways()

                    if finished:
                        # No targets left: keep draining until every slot
                        # has no active chemicals.
                        if all(len(self.active_chemicals[_id]) == 0
                               for _id in range(self.nproc)):
                            break
                        continue

                    for _id in self.get_pathway_result():
                        try:
                            smiles_id, smiles = next(self.smiles_generator)
                        except StopIteration:
                            print("We are finished!")
                            finished = True
                            break
                        self._start_pathway(_id, smiles_id, smiles)

                except Exception:
                    print("... unspecified ERROR: " + traceback.format_exc())

            self.stop()
            print("... exited prematurely.")

        except Exception:
            print("Error in coordinate: " + traceback.format_exc())
            sys.exit(1)

    def _absorb_worker_results(self):
        """Attach every finished expansion to its pathway and queue children."""
        for (_id, chem_smi, depth,
             precursors) in self.get_ready_result():
            children = self.add_reactants(_id, chem_smi, depth, precursors)
            self.active_chemicals[_id].remove((chem_smi, depth))
            if children in ('cyclic', 'unexpandable'):
                # Dead-end node already recorded by add_reactants; the
                # pathway itself stays alive.
                continue
            if children and (len(children) + self.pathway_status[_id][0] <=
                             self.pathway_status[_id][2]):
                self.active_chemicals[_id].extend(children)
                self.expand_products(_id, children, self.rollout_branching)
            else:
                # Nothing to expand, or expanding would exceed the budget.
                self.pathway_status[_id][1] = False

    def _retire_stalled_pathways(self):
        """Retire every occupied slot whose pathway can no longer change."""
        for _id in range(self.nproc):
            if not self.pathways[_id]:
                continue  # empty slot
            nodes = list(self.pathways[_id]['chemical_nodes'].values())
            all_processed = all(node.processed for node in nodes)
            queues_empty = (self.results_queues[_id].empty()
                            and self.expansion_queues[_id].empty())
            no_active = not self.active_chemicals[_id]
            is_dead = (bool(self.idle[_id]) and queues_empty
                       and not self.pathway_status[_id][1])
            if is_dead:
                # A dead pathway with an idle worker and empty queues cannot
                # change once it has no active chemicals, so retire it even
                # if some nodes were left unprocessed (e.g. children dropped
                # for exceeding the budget).
                retire = all_processed or no_active
            else:
                retire = (queues_empty and no_active and bool(nodes)
                          and all_processed)
            if retire:
                self._retire_pathway(_id)

    def _retire_pathway(self, _id):
        """Score pathway ``_id``, clear its slot and announce the slot free."""
        self.update_tree(_id)
        self.pathways[_id] = 0
        self.active_chemicals[_id] = []
        self.pathways_queue.put(_id)

    def _start_pathway(self, _id, smiles_id, smiles):
        """Install a fresh pathway for ``smiles`` in slot ``_id`` and seed it."""
        self.pathways[_id] = {
            'chemicals': set(),
            'chemical_nodes': {},
            'reaction_nodes': {},
            'target': smiles,
            'smiles_id': smiles_id
        }
        self.pathway_status[_id] = [0, True, self.total_applied_templates]
        self.set_initial_target(_id, [(smiles, 0)])

    def work(self, i):
        """Worker loop for local (non-Celery) mode.

        ``prepare`` starts this in a child process via
        ``Process(target=self.work, args=(i, ))``. It repeatedly takes
        ``(pathway_id, smiles, depth, branching)`` from
        ``self.expansion_queues[i]`` and asks the retro transformer for
        precursors. It keeps the top ``self.rollout_branching`` of them; with
        probability ``self.epsilon`` it keeps one random pick among those
        instead. The result ``(pathway_id, smiles, depth, precursors)`` goes
        on ``self.results_queues[pathway_id]``. The loop exits once
        ``self.done.value`` is set and sleeps in 1 s steps while
        ``self.paused.value`` is set.

        Args:
            i (int): Index of this worker; selects its expansion queue and
                its slot in ``self.idle``.
        """

        use_mincost = False
        # Prioritizers handed to get_outcomes (replaced in MinCost mode).
        prioritizers = (self.precursor_prioritization,
                        self.template_prioritization)

        # MinCost mode: outcomes are generated with gc.relevance_precursor,
        # then each precursor is re-scored below with the MinCost model
        # loaded from train/fit/<policy_iteration>/.
        if self.precursor_prioritization == gc.mincost:
            print "Loading model weights train/fit/{}/".format(
                self.policy_iteration)
            from makeit.prioritization.precursors.mincost import MinCostPrecursorPrioritizer
            model = MinCostPrecursorPrioritizer()
            model.load_model(
                datapath='train/fit/{}/'.format(self.policy_iteration))
            prioritizers = (gc.relevance_precursor,
                            self.template_prioritization)
            use_mincost = True

        while True:
            # If done, stop
            if self.done.value:
                print 'Worker {} saw done signal, terminating'.format(i)
                #MyLogger.print_and_log(
                #	'Worker {} saw done signal, terminating'.format(i), treebuilder_loc)
                break

            # If paused, wait and check again
            if self.paused.value:
                time.sleep(1)
                continue

            # Grab something off the queue
            try:
                # Busy is set before taking an item, and idle is restored
                # only after the result has been published. The coordinator,
                # which checks idle + empty queues, therefore cannot see this
                # worker as idle while an item is in flight.
                self.idle[i] = False
                # NOTE(review): ``branching`` is read but never used. The
                # number of precursors kept is always self.rollout_branching,
                # even for the root, which set_initial_target queues with
                # self.expansion_branching. Confirm this is intended.
                (jj, smiles, depth, branching) = self.expansion_queues[i].get(
                    timeout=0.25)  # short timeout
                #prioritizers = (self.precursor_prioritization, self.template_prioritization)
                outcomes = self.retroTransformer.get_outcomes(
                    smiles,
                    self.mincount,
                    prioritizers,
                    depth=depth,
                    template_count=self.template_count,
                    mode=self.precursor_score_mode,
                    max_cum_prob=self.max_cum_template_prob)
                if use_mincost:
                    # New score = 1 + sum of |MinCost score| of each
                    # precursor molecule evaluated at depth + 1.
                    for precursor in outcomes.precursors:
                        precursor.retroscore = 1.0 + sum([
                            abs(model.get_score_from_smiles(smile, depth + 1))
                            for smile in precursor.smiles_list
                        ])
                        #print smiles, precursor.retroscore, precursor.smiles_list

                reaction_precursors = outcomes.return_top(
                    n=self.rollout_branching)

                # Epsilon-greedy: with probability self.epsilon keep a single
                # random choice among the top candidates.
                if (random.random() <
                        self.epsilon) and len(reaction_precursors) > 0:
                    reaction_precursors = [random.choice(reaction_precursors)]
                self.results_queues[jj].put(
                    (jj, smiles, depth, reaction_precursors))

            except VanillaQueue.Empty:
                # Nothing queued within the timeout; poll again.
                pass

            except Exception as e:
                print traceback.format_exc()

            time.sleep(0.01)
            self.idle[i] = True

    def add_reactants(self, _id, chem_smi, depth, precursors):
        """Attach the top-ranked non-cyclic precursor set of a node.

        Called with a worker result for the chemical ``(chem_smi, depth)`` of
        pathway slot ``_id``. Steps:

        1. Mark that node processed.
        2. Drop precursor sets that re-introduce a chemical already in the
           pathway whose earliest-depth node is unpriced and has cost
           ``>= self.max_penalty`` or ``== -1`` (treated as cyclic).
        3. Sort the rest ascending by ``(score, total reactant SMILES
           length)``.
        4. Expand only the first set: a Reaction node at ``depth + 1`` and a
           Chemical node, priced via ``get_price``, for each new reactant.

        Args:
            _id (int): Pathway slot index into ``self.pathways``.
            chem_smi (str): SMILES of the expanded chemical.
            depth (int): Depth of the expanded chemical.
            precursors: Worker results. Each item is indexed by
                ``'smiles_split'``, ``'score'``, ``'tforms'`` and
                ``'template_score'``.

        Returns:
            The result is one of:

            * a list of ``(smiles, depth + 1)`` children when the pathway may
              keep growing;
            * ``'unexpandable'`` when there were no precursors or all were
              rejected as cyclic;
            * ``False`` when nothing further should be expanded (the chosen
              reactants' nodes are then marked processed);
            * ``None`` if an exception occurred (it is printed, not raised).
        """
        try:
            self.pathways[_id]['chemical_nodes'][(chem_smi,
                                                  depth)].processed = True
            # If no templates applied, do not go further, chemical not makeable.
            if not precursors:
                self.pathways[_id]['chemical_nodes'][chem_smi,
                                                     depth].retro_results = []
                return 'unexpandable'
                #return False

            # Each entry: [score, reactant SMILES list, template score, template].
            scores_list = []
            for result in precursors:
                reactants = result['smiles_split']
                retroscore = result['score']
                template_action = result['tforms']
                template_probability = result['template_score']

                # Reject cyclic templates as 'illegal moves'.
                cyclic_template = False
                for q, smi in enumerate(reactants):
                    if smi in self.pathways[_id]['chemicals']:
                        # Shallowest existing node carrying this SMILES.
                        reactant_smile_key = sorted([(rchem_smi, rdepth) for (
                            rchem_smi,
                            rdepth) in self.pathways[_id]['chemical_nodes']
                                                     if (rchem_smi == smi)],
                                                    key=lambda x: x[1])[0]
                        # NOTE: for a None price, ``>= 0`` relies on Python 2
                        # ordering (None < numbers); Python 3 raises TypeError.
                        if not (self.pathways[_id]['chemical_nodes']
                                [reactant_smile_key].purchase_price >= 0):
                            last_reactant_cost = self.pathways[_id][
                                'chemical_nodes'][reactant_smile_key].cost
                            if (last_reactant_cost >= self.max_penalty) or (
                                    last_reactant_cost == -1):
                                cyclic_template = True
                                break
                if cyclic_template:
                    continue
                scores_list.append([
                    retroscore, reactants, template_probability,
                    template_action
                ])

            if not scores_list:
                self.pathways[_id]['chemical_nodes'][chem_smi,
                                                     depth].retro_results = []
                return 'unexpandable'

            # Ascending by score, ties broken by total reactant SMILES length.
            # NOTE(review): the first (smallest score) entry is the one used;
            # confirm lower is better for every prioritizer in use.
            results = sorted(scores_list,
                             key=lambda x:
                             (x[0], sum([len(xx) for xx in x[1]])))
            self.pathways[_id]['chemical_nodes'][
                chem_smi, depth].retro_results = results  #precursors

            #for result in results:
            #	print chem_smi, depth, result[0], result[1]

            # Only the top-ranked result is used: the loop breaks after p == 0.
            for p, result in enumerate(results):
                react_cost, reactants, template_prob, template_no = result
                if isinstance(reactants, list):
                    rxn_smi = ".".join(reactants)
                else:
                    rxn_smi = reactants

                if p == 0:
                    children = []
                    self.pathways[_id]['chemical_nodes'][(
                        chem_smi, depth)].add_incoming_reaction(
                            rxn_smi, (template_no, template_prob))
                    # NOTE(review): this replaces the full sorted list stored
                    # above with only the chosen result; confirm intended.
                    self.pathways[_id]['chemical_nodes'][(
                        chem_smi, depth)].retro_results = result
                    self.pathways[_id]['reaction_nodes'][(rxn_smi, depth +
                                                          1)] = Reaction(
                                                              rxn_smi,
                                                              depth + 1)
                    self.pathways[_id]['reaction_nodes'][(
                        rxn_smi, depth + 1)].add_outgoing_chemical(
                            chem_smi, (template_no, template_prob))

                    for q, smi in enumerate(reactants):
                        if (smi, depth + 1) in children: continue
                        children.append((smi, depth + 1))
                        self.pathways[_id]['reaction_nodes'][(
                            rxn_smi, depth + 1)].add_incoming_chemical(smi)
                        if (smi, depth +
                                1) not in self.pathways[_id]['chemical_nodes']:
                            self.pathways[_id]['chemical_nodes'][smi, depth +
                                                                 1] = Chemical(
                                                                     smi,
                                                                     depth + 1)
                            # get_price returns 0.0 (buyable) or None.
                            ppg = self.get_price(smi)
                            self.pathways[_id]['chemical_nodes'][smi, depth +
                                                                 1].price(ppg)
                            # Buyable, at max depth, or pathway already dead:
                            # nothing left to do for this node.
                            if (ppg >= 0.0) or (
                                (depth + 1) == self.max_depth) or (
                                    not self.pathway_status[_id][1]):
                                self.pathways[_id]['chemical_nodes'][
                                    smi, depth + 1].processed = True
                            # A non-buyable chemical at max depth kills the pathway.
                            if ((depth + 1)
                                    == self.max_depth) and (not ppg >= 0.0):
                                self.pathway_status[_id][1] = False
                break

            #################################
            if children and (
                    depth < self.max_depth
            ) and self.pathway_status[_id][1]:  #not cyclic_template
                return children
            else:
                #print "Warning (ii): Nothing left to expand.", cyclic_template
                # ``reactants`` is the chosen precursor set from the loop above.
                for smi in reactants:
                    try:
                        if (not self.pathways[_id]['chemical_nodes'][
                                smi, depth + 1].processed):
                            self.pathways[_id]['chemical_nodes'][
                                smi, depth + 1].processed = True
                    except:
                        pass
                return False

        except Exception as E:
            print "Error in add_reactants:", traceback.format_exc()
            print self.pathways[_id]['chemicals']
            for key, c in self.pathways[_id]['chemical_nodes'].items():
                print key, type(c.retro_results)

    def expand_products(self, _id, children, branching):
        """Price new chemicals of pathway ``_id`` and queue those needing work.

        Each ``(chem_smi, depth)`` in ``children`` is expected to be in
        ``self.active_chemicals[_id]``; ``.remove`` raises ValueError
        otherwise, which the handler below prints. Each child is handled as
        follows:

        * depth >= ``self.max_depth``: removed from the active list, and the
          pathway is marked dead (``pathway_status[_id][1] = False``);
        * already in the pathway and not buyable: treated as a cycle; removed,
          pathway marked dead;
        * otherwise registered in the pathway (node created if missing, then
          priced). A buyable chemical is marked processed and removed;
        * if the pathway's expansion count has reached its budget
          (``pathway_status[_id][0] >= pathway_status[_id][2]``): marked
          processed, removed, pathway marked dead;
        * otherwise the expansion count is incremented and the chemical is
          passed to ``self.expand``.

        Args:
            _id (int): Pathway slot index.
            children: Iterable of ``(smiles, depth)`` pairs.
            branching: Forwarded unchanged to ``self.expand``.

        Returns:
            int: Number of chemicals passed to ``self.expand``. Returns
            ``None`` if an exception occurred (it is printed, not raised).
        """
        try:
            synthetic_expansion_candidates = 0
            for (chem_smi, depth) in children:
                if depth >= self.max_depth:
                    self.active_chemicals[_id].remove((chem_smi, depth))
                    self.pathway_status[_id][1] = False
                    continue

                # get_price returns 0.0 (buyable) or None. NOTE: ``None >= 0``
                # is False on Python 2 but raises TypeError on Python 3.
                ppg = self.get_price(
                    chem_smi)  #self.Chemicals[chem_smi,depth].purchase_price
                # Seen before in this pathway and not buyable -> cycle.
                if chem_smi in self.pathways[_id]['chemicals']:
                    if not (ppg >= 0.0):
                        self.active_chemicals[_id].remove((chem_smi, depth))
                        self.pathway_status[_id][1] = False
                        continue

                self.pathways[_id]['chemicals'].add(chem_smi)
                if (chem_smi,
                        depth) not in self.pathways[_id]['chemical_nodes']:
                    self.pathways[_id]['chemical_nodes'][chem_smi,
                                                         depth] = Chemical(
                                                             chem_smi, depth)
                self.pathways[_id]['chemical_nodes'][chem_smi,
                                                     depth].price(ppg)

                # Buyable: leaf, nothing to expand.
                if ppg >= 0:
                    self.pathways[_id]['chemical_nodes'][(
                        chem_smi, depth)].processed = True
                    self.active_chemicals[_id].remove((chem_smi, depth))
                    continue

                # Expansion budget exhausted: stop this pathway.
                if not (self.pathway_status[_id][0] <
                        self.pathway_status[_id][2]):
                    self.pathways[_id]['chemical_nodes'][(
                        chem_smi, depth)].processed = True
                    self.active_chemicals[_id].remove((chem_smi, depth))
                    self.pathway_status[_id][1] = False
                    continue

                synthetic_expansion_candidates += 1
                self.pathway_status[_id][0] += 1

                # TO-DO
                # If we have already expanded the node, don't re-do it.
                # Form for results_queue.put(): (jj, smiles, depth, precursors, pathway)
                self.expand(_id, chem_smi, depth, branching)

            return synthetic_expansion_candidates

        except Exception as e:
            print "Error in expand_products:", traceback.format_exc()

    def target_generator(self):
        """Return a fresh generator over the configured target chemicals.

        Thin wrapper that delegates to ``target_generator_func``.
        """
        generator = self.target_generator_func()
        return generator

    def target_generator_func(self):
        """Yield each entry of ``self.target_chemicals`` in order.

        ``self.target_chemicals`` is only read once iteration starts;
        ``build_tree`` unpacks every yielded entry as ``(smiles_id, smiles)``.
        """
        for target in self.target_chemicals:
            yield target

    def build_tree(self):
        """Seed one pathway per worker from the target list, then run the search.

        Pulls up to ``self.nproc`` ``(smiles_id, smiles)`` pairs from
        ``self.target_generator()``. For each worker index ``k`` that receives
        a target:

        - a fresh pathway dict is stored in ``self.pathways[k]``;
        - its status is reset to ``[0, True, self.total_applied_templates]``;
        - the target is registered as a depth-0 leaf via
          ``self.set_initial_target``.

        Worker slots that get no target (fewer targets than workers) keep their
        previous ``self.pathways`` entry. Finally ``self.coordinate()`` drives
        the workers.
        """
        self.running = True
        self.prepare()

        self.smiles_generator = self.target_generator()
        for k in range(self.nproc):
            try:
                # Builtin next() works on Python 2.6+ and 3 (``.next()`` is
                # Python 2 only).
                smiles_id, smiles = next(self.smiles_generator)
            except StopIteration:
                print("(a) We are finished!")
                break
            leaves = [(smiles, 0)]
            pathway = {
                'chemicals': set(),
                'chemical_nodes': {},
                'reaction_nodes': {},
                'target': smiles,
                'smiles_id': smiles_id
            }
            self.pathways[k] = pathway
            self.pathway_status[k] = [0, True, self.total_applied_templates]
            self.set_initial_target(k, leaves)

        # Coordinate workers.
        self.coordinate()
        # NOTE: saving the CRN and the value-network training states was
        # disabled in this implementation and is intentionally not done here.
        print("Finished working.")

    def reset(self):
        """Reinitialize the per-search worker bookkeeping.

        In celery mode no local multiprocessing state is created and only
        ``self.active_chemicals`` is reset. Otherwise a new ``Manager`` with
        its shared ``done``/``paused`` values and ``idle`` list is created,
        together with fresh queues and per-worker pathway bookkeeping sized by
        ``self.nproc``.
        """
        if not self.celery:
            manager = Manager()
            self.manager = manager
            self.done = manager.Value('i', 0)
            self.paused = manager.Value('i', 0)
            self.idle = manager.list()
            self.results_queue = Queue()
            self.workers = []
            self.coordinator = None
            self.running = False

            # Per-worker pathway bookkeeping.
            worker_ids = range(self.nproc)
            self.pathways = [0] * self.nproc
            self.pathways_queue = Queue()
            # Comprehension (not list multiplication) so each worker gets its
            # own status list.
            self.pathway_status = [[0, True, self.total_applied_templates]
                                   for _ in worker_ids]
            self.sampled_pathways = []
            self.pathway_count = 0
            self.successful_pathway_count = 0

            # Every worker starts idle.
            self.idle.extend([True] * self.nproc)
            # One expansion queue and one results queue per worker; this also
            # covers the single-process case.
            self.expansion_queues = [Queue() for _ in worker_ids]
            self.results_queues = [Queue() for _ in worker_ids]
        self.active_chemicals = [[] for _ in range(self.nproc)]

    def get_buyable_paths(self,
                          target_chemicals,
                          replica=0,
                          fileName=None,
                          max_depth=10,
                          expansion_time=300,
                          expansion_branching=1,
                          rollout_branching=1,
                          total_applied_templates=1000,
                          noise_std_percentage=None,
                          template_prioritization=gc.relevance,
                          precursor_prioritization=gc.heuristic,
                          policy_iteration=None,
                          nproc=8,
                          mincount=25,
                          chiral=True,
                          epsilon=0.0,
                          template_count=50,
                          precursor_score_mode=gc.max,
                          max_cum_template_prob=0.995):
        """Configure a search over ``target_chemicals`` and run it.

        Every option except ``chiral`` is stored on the instance under the
        same name. The one exception to "stored as given" is ``epsilon``: it
        is forced to 1.0 when ``precursor_prioritization == 'random'``.

        The method then (re)creates the multiprocessing bookkeeping:
        manager-backed shared values, queues and per-worker pathway slots.
        In celery mode the idle flags and per-worker expansion/results queues
        are not set up. Finally it delegates to ``build_tree``.

        Returns:
            Whatever ``self.build_tree()`` returns.
        """
        # Search configuration.
        self.target_chemicals = target_chemicals
        self.replica = replica
        self.fileName = fileName
        self.mincount = mincount
        self.max_depth = max_depth
        self.expansion_branching = expansion_branching
        self.expansion_time = expansion_time
        self.rollout_branching = rollout_branching
        self.total_applied_templates = total_applied_templates
        self.template_prioritization = template_prioritization
        self.precursor_prioritization = precursor_prioritization
        self.precursor_score_mode = precursor_score_mode
        self.nproc = nproc
        self.template_count = template_count
        self.max_cum_template_prob = max_cum_template_prob
        # A 'random' precursor prioritization overrides epsilon with 1.0.
        if precursor_prioritization == 'random':
            self.epsilon = 1.0
        else:
            self.epsilon = epsilon
        self.noise_std_percentage = noise_std_percentage
        self.policy_iteration = policy_iteration

        self.depth_penalty = score_max_depth()
        self.max_penalty = score_no_templates()

        # Shared state for python multiprocessing workers.
        manager = Manager()
        self.manager = manager
        self.done = manager.Value('i', 0)
        self.paused = manager.Value('i', 0)
        # Idle flags; filled with one entry per worker below (non-celery only).
        self.idle = manager.list()
        self.results_queue = Queue()
        self.workers = []
        self.coordinator = None
        self.running = False

        # Per-worker pathway bookkeeping.
        slots = range(nproc)
        self.pathways = [0] * nproc
        self.pathways_queue = Queue()
        self.pathway_status = [[0, True, total_applied_templates]
                               for _ in slots]
        self.sampled_pathways = []
        self.pathway_count = 0
        self.successful_pathway_count = 0

        if not self.celery:
            self.idle.extend([True] * nproc)
            # One expansion queue and one results queue per worker.
            self.expansion_queues = [Queue() for _ in slots]
            self.results_queues = [Queue() for _ in slots]
        self.active_chemicals = [[] for _ in slots]

        return self.build_tree()
Esempio n. 11
0
class SCScorePrecursorPrioritizer(Prioritizer):
    """Standalone, importable SCScore scorer model.

    It does not have tensorflow as a dependency and is a more attractive option
    for deployment. The calculations are fast enough that there is no real
    reason to use GPUs (via tf) instead of CPUs (via np).

    Attributes:
        vars (list of np.ndarray): Weights and biases of the loaded model,
            stored as alternating ``[W0, b0, W1, b1, ...]``.
        FP_rad (int): Fingerprint radius.
        FP_len (int): Fingerprint length (set by ``load_model``).
        score_scale (float): Upper-bound of scale for scoring.
        pricer (Pricer or None): Pricer instance to lookup chemical costs.
    """
    def __init__(self, score_scale=5.0):
        """Initializes SCScorePrecursorPrioritizer.

        No model is loaded here; ``load_model`` is called lazily by
        ``get_priority`` on first use.

        Args:
            score_scale (float, optional): Upper-bound of scale for scoring.
                (default: {5.0})
        """
        self.vars = []
        self.FP_rad = 2
        self.score_scale = score_scale
        self._restored = False
        self.pricer = None
        self._loaded = False

    def load_model(self, FP_len=1024, model_tag='1024bool'):
        """Loads model weights, the fingerprint function and the pricer.

        Args:
            FP_len (int, optional): Fingerprint length. (default: {1024})
            model_tag (str, optional): Tag of model to load. One of
                '1024bool', '1024uint8' or '2048bool'; any other tag logs a
                warning and falls back to '1024bool'.
                (default: {'1024bool'})
        """
        self.FP_len = FP_len
        if model_tag not in ('1024bool', '1024uint8', '2048bool'):
            MyLogger.print_and_log(
                'Non-existent SCScore model requested: {}. Using "1024bool" model'
                .format(model_tag),
                scscore_prioritizer_loc,
                level=2)
            model_tag = '1024bool'
        model_path = gc.SCScore_Prioritiaztion['trained_model_path_' + model_tag]
        # NOTE: pickle can execute arbitrary code; model_path comes from the
        # global config and must point to a trusted file.
        with open(model_path, 'rb') as fid:
            self.vars = pickle.load(fid)
        if gc.DEBUG:
            MyLogger.print_and_log(
                'Loaded synthetic complexity score prioritization model from {}'
                .format(model_path),
                scscore_prioritizer_loc)

        if 'uint8' in model_path:

            def mol_to_fp(mol):
                """Returns count fingerprint of molecule for uint8 model.

                Args:
                    mol (Chem.rdchem.Mol or None): Molecule to get fingerprint
                        of.

                Returns:
                    np.ndarray of np.uint8: Folded fingerprint of length
                        ``FP_len``; all zeros when ``mol`` is None.
                """
                if mol is None:
                    # A zero vector of the right length (np.array((n,)) would
                    # build a one-element array holding n instead).
                    return np.zeros((self.FP_len, ), dtype=np.uint8)
                fp = AllChem.GetMorganFingerprint(
                    mol, self.FP_rad, useChirality=True)  # sparse count vect
                fp_folded = np.zeros((self.FP_len, ), dtype=np.uint8)
                # Fold sparse bit ids onto FP_len positions, summing counts.
                for k, v in fp.GetNonzeroElements().items():
                    fp_folded[k % self.FP_len] += v
                return fp_folded
        else:

            def mol_to_fp(mol):
                """Returns bit fingerprint of molecule for bool model.

                Args:
                    mol (Chem.rdchem.Mol or None): Molecule to get fingerprint
                        of.

                Returns:
                    np.ndarray of bool or np.float32: Fingerprint of length
                        ``FP_len``; float32 zeros when ``mol`` is None.
                """
                if mol is None:
                    return np.zeros((self.FP_len, ), dtype=np.float32)
                # Builtin ``bool``: the ``np.bool`` alias was removed in
                # NumPy 1.24.
                return np.array(AllChem.GetMorganFingerprintAsBitVect(
                    mol, self.FP_rad, nBits=self.FP_len, useChirality=True),
                                dtype=bool)

        self.mol_to_fp = mol_to_fp

        self.pricer = Pricer()
        self.pricer.load()
        self._restored = True
        self._loaded = True

    def smi_to_fp(self, smi):
        """Returns fingerprint of molecule from given SMILES string.

        Args:
            smi (str): SMILES string of given molecule.

        Returns:
            np.ndarray: Fingerprint from ``mol_to_fp``; float32 zeros of
                length ``FP_len`` for an empty SMILES string.
        """
        if not smi:
            return np.zeros((self.FP_len, ), dtype=np.float32)
        return self.mol_to_fp(Chem.MolFromSmiles(str(smi)))

    def apply(self, x):
        """Applies model to a fingerprint to calculate score.

        Args:
            x (np.ndarray): Fingerprint of molecule to apply model to.

        Returns:
            float: Score of molecule, squashed into ``(1, score_scale)`` by a
                sigmoid on the last layer's output.

        Raises:
            ValueError: If model weights have not been loaded.
        """
        if not self._restored:
            raise ValueError('Must restore model weights!')
        # Each pair of vars is a weight and bias term
        for i in range(0, len(self.vars), 2):
            last_layer = (i == (len(self.vars) - 2))
            W = self.vars[i]
            b = self.vars[i + 1]
            x = np.dot(W.T, x) + b
            if not last_layer:
                x = x * (x > 0)  # ReLU
        x = 1 + (self.score_scale - 1) * sigmoid(x)
        return x

    def get_priority(self, retroProduct, **kwargs):
        """Returns priority of given product.

        Priority is the negated score: the higher the score, the lower the
        priority.

        Args:
            retroProduct (str or RetroPrecursor): Product to calculate score
                for. For a non-string, the scores of its ``smiles_list`` are
                merged with ``merge_scores``.
            **kwargs: Additional optional arguments. Used for mode.

        Returns:
            float: Priority of given product; ``-inf`` for empty input.
        """
        mode = kwargs.get('mode', gc.max)
        if not self._loaded:
            self.load_model()

        # Empty input (None or '') gets the lowest possible priority. This
        # guard must come before the dispatch below, which always returns.
        if not retroProduct:
            return -np.inf

        if isinstance(retroProduct, str):
            return -self.get_score_from_smiles(retroProduct)
        scores = [self.get_score_from_smiles(smiles)
                  for smiles in retroProduct.smiles_list]
        return -self.merge_scores(scores, mode=mode)

    def merge_scores(self, list_of_scores, mode=gc.max):
        """Merges list of scores into a single score based on a given mode.

        Args:
            list_of_scores (list of floats): Scores to be merged.
            mode (str, optional): Function to merge by: ``gc.mean``,
                ``gc.geometric``, ``gc.pow8`` (sum of ``8**score``), anything
                else takes the maximum. (default: {gc.max})

        Returns:
            float: Merged scores.

        Note:
            An empty list is not special-cased: the max mode raises
            ``ValueError`` and the geometric mode ``ZeroDivisionError``.
        """
        if mode == gc.mean:
            return np.mean(list_of_scores)
        elif mode == gc.geometric:
            return np.power(np.prod(list_of_scores), 1.0 / len(list_of_scores))
        elif mode == gc.pow8:
            return np.sum([8**score for score in list_of_scores])
        else:
            return np.max(list_of_scores)

    def get_score_from_smiles(self, smiles, noprice=False):
        """Returns score of molecule from given SMILES string.

        Args:
            smiles (str): SMILES string of molecule.
            noprice (bool, optional): Whether to not use the molecules price as
                its score, if available. (default: {False})

        Returns:
            The price divided by 100 when the pricer finds a truthy price and
            ``noprice`` is False. Otherwise 0.0 for an all-zero fingerprint
            (e.g. empty or unparsable SMILES), else the output of ``apply``.
        """
        # Check buyable
        if not noprice:
            ppg = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
            if ppg:
                return ppg / 100.

        fp = np.array(self.smi_to_fp(smiles), dtype=np.float32)
        # Fingerprint entries are non-negative, so "no non-zero entry" is the
        # same as a zero sum, without a Python-level loop.
        if not fp.any():
            return 0.
        # Run
        return self.apply(fp)