Ejemplo n.º 1
0
def get_delta_labels(reactant_mol, product_mol):
    """Build the ground-truth delta labels between a reactant and its product.

    Edge, hydrogen and charge deltas come from the ``_get_*_delta_label``
    helpers. They are then post-processed for every reactant atom index that
    does not appear among the ``idxfunc`` values of ``product_mol``'s atoms
    (a "missing" atom):

    * edge deltas between two missing atoms are set to 0;
    * a missing atom that is in ``get_reactant_atom_idx(...)`` gets a hydrogen
      delta equal to the negated sum of its edge-delta row ("assume h on
      break") and a charge delta of 0;
    * any other missing atom gets hydrogen and charge deltas of 0.

    Args:
        reactant_mol: reactant molecule (exposes ``GetAtoms``/``GetNumAtoms``).
        product_mol: product molecule.

    Returns:
        dict with keys ``EDGE_DELTA_KEY``, ``H_DELTA_KEY``, ``C_DELTA_KEY`` and
        ``OCTET_SUM_KEY``. ``octet_sum[i]`` is the sum of row ``i`` of the edge
        deltas plus atom ``i``'s hydrogen and charge deltas, stored in an
        integer array of length ``reactant_mol.GetNumAtoms()``.
    """
    product_atom_idx = {idxfunc(atom) for atom in product_mol.GetAtoms()}
    reactant_atom_idx = get_reactant_atom_idx(get_reactant_mols(reactant_mol),
                                              product_mol)

    edge_deltas = _get_edge_delta_label(reactant_mol, product_mol)
    h_deltas = _get_hydrogen_delta_label(reactant_mol, product_mol)
    c_deltas = _get_charge_delta_label(reactant_mol, product_mol)

    num_atom = reactant_mol.GetNumAtoms()

    # True for reactant atom indices that do not appear in the product.
    missing = np.array([idx not in product_atom_idx for idx in range(num_atom)],
                       dtype=bool)
    # Zero every edge delta between two missing atoms in one step; this is
    # equivalent to the former pairwise double loop over all atom pairs.
    edge_deltas[np.ix_(missing, missing)] = 0

    # `np.int` was only an alias for the builtin `int`; it was deprecated in
    # NumPy 1.20 and removed in 1.24, so use `int` directly.
    octet_sum = np.zeros(num_atom, dtype=int)
    for idx in range(num_atom):
        if missing[idx]:
            if idx in reactant_atom_idx:
                # assume h on break
                h_deltas[idx] = -np.sum(edge_deltas[idx])
            else:
                h_deltas[idx] = 0
            c_deltas[idx] = 0

        octet_sum[idx] = np.sum(
            edge_deltas[idx]) + h_deltas[idx] + c_deltas[idx]

    return {
        EDGE_DELTA_KEY: edge_deltas,
        H_DELTA_KEY: h_deltas,
        C_DELTA_KEY: c_deltas,
        OCTET_SUM_KEY: octet_sum
    }
Ejemplo n.º 2
0
    def _find_top_k(delta_pred, top_k):
        """Return the top-scoring predicted bond edits and the true edits.

        Args:
            delta_pred: prediction record; reads ``OUTPUT_REACTION_STR_KEY``
                and ``OUTPUT_EDGE_DELTA_KEY``. The edge-delta prediction is
                indexed below with three coordinates (atom, atom, delta class),
                so it is assumed to be a 3-D array -- TODO confirm.
            top_k: maximum number of predicted edits to return.

        Returns:
            ``(top_k_edits, edit_str)``, both lists of strings formatted as
            ``"<x+1>-<y+1>-<delta>"`` (1-based atom indices, ``x < y``).
            ``top_k_edits`` is ordered by decreasing predicted score. It holds
            at most ``top_k`` entries, and fewer if the prediction has too few
            qualifying cells (previously this case raised IndexError).
            ``edit_str`` lists every nonzero upper-triangle entry of the
            ground-truth edge-delta label from ``get_delta_labels``.
        """
        reactant_mol, product_mol = get_reactant_product_molecule(
            delta_pred[OUTPUT_REACTION_STR_KEY])
        reactant_atom_idx = get_reactant_atom_idx(
            get_reactant_mols(reactant_mol), product_mol)

        reaction_center_proba = delta_pred[OUTPUT_EDGE_DELTA_KEY]
        # Multi-dimensional indices of all cells, sorted by descending score.
        idxs = np.unravel_index(
            np.argsort(reaction_center_proba.ravel())[::-1],
            reaction_center_proba.shape)
        top_k_edits = []
        # Bounded iteration: stop once top_k edits are collected, or when the
        # candidate cells run out.
        for x, y, z in zip(idxs[0], idxs[1], idxs[2]):
            if len(top_k_edits) >= top_k:
                break
            # Delta class 3 is skipped (presumably the "no change" class --
            # confirm against EDGE_DELTA_VAL_LIST). x >= y skips the diagonal
            # and the lower triangle, so each atom pair is counted once.
            if z == 3 or x >= y:
                continue
            if x in reactant_atom_idx and y in reactant_atom_idx:
                top_k_edits.append('{}-{}-{}'.format(x + 1, y + 1,
                                                     EDGE_DELTA_VAL_LIST[z]))

        delta_labels = get_delta_labels(reactant_mol, product_mol)
        edge_delta_label = delta_labels[EDGE_DELTA_KEY]
        edit_str = []
        for x, y in zip(*np.nonzero(edge_delta_label)):
            if x >= y:
                continue
            edit_str.append('{}-{}-{}'.format(x + 1, y + 1,
                                              edge_delta_label[x, y]))

        return top_k_edits, edit_str
Ejemplo n.º 3
0
def eval_gurobi_sampler(delta_pred_list):
    """Run the Gurobi sampler on each prediction record and log the outcome.

    For every record the sampler is run and its solutions are de-duplicated by
    the SMILES of their primary product, with atom-map numbers cleared. The
    0-based rank of the first solution that ``is_same_molecule`` matches
    against the true product is recorded, or -1 if none matches.

    Side effects:
        * each ``delta_pred`` is mutated: ``SAMPLE_SOLUTION_MOL_KEY`` becomes
          the list of solution SMILES and ``SAMPLE_SOLUTION_VAL_KEY`` the list
          of objective values;
        * the rank and the sampler wall time are appended to
          ``RESULT[GUROBI_KEY][COUNT_KEY]`` and ``RESULT[GUROBI_KEY][TIME_KEY]``;
        * progress is printed when ``args.verbose`` is set, and
          ``log_result`` is called every ``args.log_interval`` records and
          once at the end.

    Returns None.
    """
    processed = []

    for step, delta_pred in enumerate(delta_pred_list):
        reactant_mol, product_mol = get_reactant_product_molecule(
            delta_pred[OUTPUT_REACTION_STR_KEY])
        reactant_atom_idx = get_reactant_atom_idx(
            get_reactant_mols(reactant_mol), product_mol)

        started = perf_counter()
        raw_solutions = run_gurobi_sampler(delta_pred,
                                           num_candidates=max(top_ks),
                                           verbose=args.verbose)
        elapsed = perf_counter() - started

        # Keep the first solution for each distinct primary product.
        seen = set()
        unique = []
        for candidate in raw_solutions:
            primary = find_primary_product_using_reactant_idx(
                Chem.MolToSmiles(candidate[SAMPLE_SOLUTION_MOL_KEY]),
                reactant_atom_idx)
            for atom in primary.GetAtoms():
                atom.SetAtomMapNum(0)
            smiles = Chem.MolToSmiles(primary)
            if smiles not in seen:
                seen.add(smiles)
                unique.append(candidate)

        first_hit = -1
        smiles_list = []
        objective_values = []
        for rank, candidate in enumerate(unique):
            mol = candidate[SAMPLE_SOLUTION_MOL_KEY]
            smiles_list.append(Chem.MolToSmiles(mol))
            objective_values.append(candidate[SAMPLE_SOLUTION_VAL_KEY])
            if first_hit == -1 and is_same_molecule(product_mol, mol):
                first_hit = rank

        delta_pred[SAMPLE_SOLUTION_MOL_KEY] = smiles_list
        delta_pred[SAMPLE_SOLUTION_VAL_KEY] = objective_values

        gurobi_result = RESULT[GUROBI_KEY]
        gurobi_result[COUNT_KEY].append(first_hit)
        gurobi_result[TIME_KEY].append(elapsed)

        processed.append(delta_pred)

        if args.verbose:
            print('#{}: {}'.format(step, first_hit))

        if step % args.log_interval == 0:
            log_result(step)

    log_result(len(processed))
Ejemplo n.º 4
0
    def __init__(self, delta_pred,
                 num_candidates=10,
                 calibration=(EDGE_CALIBRATION_KEY,),
                 soften=True,
                 octet_rule=True,
                 verbose=False):
        """Set up and solve the Gurobi model for a single reaction prediction.

        Args:
            delta_pred: prediction record; reads ``OUTPUT_REACTION_STR_KEY``,
                ``OUTPUT_EDGE_DELTA_KEY``, ``OUTPUT_C_DELTA_KEY`` and
                ``OUTPUT_H_DELTA_KEY``.
            num_candidates: stored as ``self.num_candidates`` (presumably the
                number of solutions to sample -- confirm in the model helpers).
            calibration: collection of calibration keys. If it contains
                ``EDGE_CALIBRATION_KEY`` or ``H_CALIBRATION_KEY``, the edge or
                hydrogen calibration function becomes ``smooth_calibrate``;
                otherwise it stays ``no_calibration``. A single key passed as
                a bare string is also accepted.
            soften: if true, the edge, charge and hydrogen predictions are
                passed through ``soften_matrix``.
            octet_rule: stored as ``self.octet_rule``.
            verbose: when false, Gurobi's ``OutputFlag`` parameter is set to 0.

        The constructor sanitizes and kekulizes the reactant molecule, then
        runs ``_set_variables``, ``_set_constraints``,
        ``_set_model_objective``, ``_set_model_param`` and ``_optimize_model``
        in that order.
        """
        self.reaction_str = delta_pred[OUTPUT_REACTION_STR_KEY]
        self.edge_delta_pred = delta_pred[OUTPUT_EDGE_DELTA_KEY]
        self.c_delta_pred = delta_pred[OUTPUT_C_DELTA_KEY]
        self.h_delta_pred = delta_pred[OUTPUT_H_DELTA_KEY]
        if soften:
            self.edge_delta_pred = soften_matrix(self.edge_delta_pred)
            self.c_delta_pred = soften_matrix(self.c_delta_pred)
            self.h_delta_pred = soften_matrix(self.h_delta_pred)
        self.num_candidates = num_candidates
        self.edge_coefficient = 5.0
        self.h_coefficient = 1.0
        self.c_coefficient = 1.0
        self.octet_rule = octet_rule

        # The former default `(EDGE_CALIBRATION_KEY)` was a plain string, not
        # a tuple, so `key in calibration` did substring matching. Wrap a bare
        # string key in a 1-tuple so membership tests compare whole keys.
        if isinstance(calibration, str):
            calibration = (calibration,)
        self.edge_calibration_fn = no_calibration
        self.h_calibration_fn = no_calibration
        if EDGE_CALIBRATION_KEY in calibration:
            self.edge_calibration_fn = smooth_calibrate
        if H_CALIBRATION_KEY in calibration:
            self.h_calibration_fn = smooth_calibrate
        self.verbose = verbose

        self.reactant_mol, self.product_mol = get_reactant_product_molecule(self.reaction_str)
        Chem.SanitizeMol(self.reactant_mol)
        Chem.Kekulize(self.reactant_mol, clearAromaticFlags=True)

        self.n_atom = self.reactant_mol.GetNumAtoms()

        # Lookup tables keyed by idxfunc(atom) and bond_idx_tuple(bond).
        self.reactant_atom_map = {idxfunc(atom): atom for atom in self.reactant_mol.GetAtoms()}
        self.reactant_bond_map = {bond_idx_tuple(bond): bond for bond in self.reactant_mol.GetBonds()}

        self.reactant_atom_idx = get_reactant_atom_idx(get_reactant_mols(self.reactant_mol), self.product_mol)

        # Containers filled by the _set_* model-building helpers.
        self.idx_to_delta_vars = {}
        self.reaction_center_delta_vars = []
        self.model_objective = []

        self.model = Model('Gurobi Sampler for Octet Sampling')
        if not self.verbose:
            self.model.setParam(GRB.Param.OutputFlag, 0)

        self._set_variables()

        self._set_constraints()

        self._set_model_objective()

        self._set_model_param()

        self._optimize_model()
Ejemplo n.º 5
0
def _get_reactivity_prediction_features(reactant_mol, product_mol):
    """Return merged atom and bond features of the reactant for reactivity prediction.

    The reactant is split into its component molecules. Those components
    supply the reactant atom indices (relative to ``product_mol``) and the
    component map passed to the feature extractors.

    Returns:
        A new dict holding the atom-feature entries followed by the
        bond-feature entries; on a key clash, the bond entry wins.
    """
    components = get_reactant_mols(reactant_mol)
    atom_idx = get_reactant_atom_idx(components, product_mol)
    component_map = get_reactant_component_map(components)

    features = dict(get_mol_atom_features(reactant_mol,
                                          reactant_atom_idx=atom_idx))
    features.update(
        get_mol_bond_features(reactant_mol,
                              reactant_atom_idx=atom_idx,
                              reactant_component_map=component_map))
    return features
Ejemplo n.º 6
0
def get_candidate_ranking_features(reaction_str, candidate_str, candidate_val,
                                   edge_delta_pred, h_delta_pred,
                                   c_delta_pred):
    """Build atom/bond features used to rank one candidate product.

    The candidate's own features (computed with the reactant's atom count)
    are concatenated along the last axis with:

    * atoms: the reactant atom features, ``h_delta_pred`` and ``c_delta_pred``;
    * bonds: the reactant bond features and ``edge_delta_pred``.

    Both blocks are then passed through ``_pad_features_with_val`` with
    ``candidate_val``.

    Returns:
        dict combining the candidate atom-feature and bond-feature entries;
        on a key clash, the bond entry wins.
    """
    reactant_mol, product_mol = get_reactant_product_molecule(reaction_str)
    n_atoms = reactant_mol.GetNumAtoms()

    # Reactant-side context features.
    components = get_reactant_mols(reactant_mol)
    atom_idx = get_reactant_atom_idx(components, product_mol)
    component_map = get_reactant_component_map(components)
    reactant_atoms = get_mol_atom_features(reactant_mol,
                                           reactant_atom_idx=atom_idx)
    reactant_bonds = get_mol_bond_features(reactant_mol,
                                           reactant_atom_idx=atom_idx,
                                           reactant_component_map=component_map)

    # Candidate-side features, sized to the reactant's atom count.
    candidate_mol = Chem.MolFromSmiles(candidate_str)
    atom_part = get_mol_atom_features(candidate_mol, num_atom=n_atoms)
    bond_part = get_mol_bond_features(candidate_mol, num_atom=n_atoms)

    stacked_atoms = np.concatenate([
        atom_part[ATOM_FEATURES_KEY],
        reactant_atoms[ATOM_FEATURES_KEY],
        h_delta_pred,
        c_delta_pred,
    ], axis=-1)
    stacked_bonds = np.concatenate([
        bond_part[BOND_FEATURES_KEY],
        reactant_bonds[BOND_FEATURES_KEY],
        edge_delta_pred,
    ], axis=-1)

    atom_part[ATOM_FEATURES_KEY] = _pad_features_with_val(stacked_atoms,
                                                          candidate_val)
    bond_part[BOND_FEATURES_KEY] = _pad_features_with_val(stacked_bonds,
                                                          candidate_val)

    return {**atom_part, **bond_part}