Example #1
0
def get_reaction_core_atoms(rsmiles):
    """ Returns the indices of atoms that participate in the reaction for each molecule in the reaction.

        While pairing every product with every reactant, an atom is marked as a reaction-core atom when:
          - it is a product atom without an atom-map number (map number <= 0),
          - it is a reactant atom without an atom-map number inside a reactant that ``molecule_is_mapped`` reports as
            mapped (only checked while scanning a mapped product atom), or
          - it shares its atom-map number with an atom on the other side, but the two atoms differ in neighbourhood
            size, neighbouring atoms or neighbouring bonds (both atoms are marked).

        NOTE: This method is based on the assumption that the reaction mapping is correct and done by matching the same
        atoms in the reactants and products.

        :param rsmiles: Mapped reaction SMILES string.
        :return: Tuple ``(reactant_cores, product_cores)``; each is a list with one set of atom indices per molecule.
                 A molecule without reaction-core atoms gets an empty set. """

    reactants, _, products = parse_reaction_roles(rsmiles, as_what="mol")
    reactants_final = [set() for _ in range(len(reactants))]
    products_final = [set() for _ in range(len(products))]

    # Whether a reactant carries atom maps does not depend on the product or on the atoms being compared, so evaluate
    # it once per reactant instead of once per (product atom, reactant atom) pair.
    reactants_mapped = [molecule_is_mapped(reactant) for reactant in reactants]

    for p_ind, product in enumerate(products):
        for r_ind, reactant in enumerate(reactants):
            for p_atom in product.GetAtoms():
                p_map_num = p_atom.GetAtomMapNum()

                # Unmapped product atoms have no counterpart, so they are treated as changed.
                if p_map_num <= 0:
                    products_final[p_ind].add(p_atom.GetIdx())
                    continue

                for r_atom in reactant.GetAtoms():
                    r_map_num = r_atom.GetAtomMapNum()

                    # Unmapped atoms inside a mapped reactant are treated as changed (e.g. leaving groups).
                    if reactants_mapped[r_ind] and r_map_num <= 0:
                        reactants_final[r_ind].add(r_atom.GetIdx())
                        continue

                    # Same mapped atom on both sides: it is part of the core if its local environment changed.
                    if p_map_num == r_map_num and \
                            not (same_neighbourhood_size(p_atom.GetIdx(), product, r_atom.GetIdx(), reactant) and
                                 same_neighbour_atoms(p_atom.GetIdx(), product, r_atom.GetIdx(), reactant) and
                                 same_neighbour_bonds(p_atom.GetIdx(), product, r_atom.GetIdx(), reactant)):
                        reactants_final[r_ind].add(r_atom.GetIdx())
                        products_final[p_ind].add(p_atom.GetIdx())

    return reactants_final, products_final
Example #2
0
def extract_info_from_reaction(reaction_smiles, reaction_cores=None):
    """ Splits every reactant and product of a reaction into its reactive part (the reaction core) and its
        non-reactive part.

        :param reaction_smiles: Mapped reaction SMILES string.
        :param reaction_cores: Optional ``(reactant_cores, product_cores)`` pair of per-molecule atom-index
                               collections; computed with ``get_reaction_core_atoms`` when omitted.
        :return: ``(reactant_fragments, product_fragments)``; each is a list holding one
                 ``(reactive_part, non_reactive_part)`` tuple of ``generate_fragment_data`` results per molecule. """

    reactants, _, products = parse_reaction_roles(reaction_smiles,
                                                  as_what="mol_no_maps")
    cores = get_reaction_core_atoms(reaction_smiles) if reaction_cores is None else reaction_cores

    reactant_fragments = []
    for idx, mol in enumerate(reactants):
        AllChem.SanitizeMol(mol)
        # Highest index first, so removing one atom never shifts the index of an atom that is still to be removed.
        core_atoms = sorted(cores[0][idx], reverse=True)

        # Keep only the core atoms, then convert that core to the fragment data formats.
        core_mol, basic_core_mol = extract_core_from_mol(mol, core_atoms)
        reactive = generate_fragment_data(core_mol,
                                          reaction_side="reactant",
                                          basic_editable_mol=basic_core_mol)

        # Remove the core atoms instead, then convert what remains to the fragment data formats.
        synthon_mol, basic_synthon_mol = extract_synthons_from_reactant(mol, core_atoms)
        non_reactive = generate_fragment_data(synthon_mol,
                                              reaction_side="reactant",
                                              basic_editable_mol=basic_synthon_mol)

        reactant_fragments.append((reactive, non_reactive))

    product_fragments = []
    for idx, mol in enumerate(products):
        AllChem.SanitizeMol(mol)
        core_atoms = sorted(cores[1][idx], reverse=True)

        core_mol, _unused = extract_core_from_mol(mol, core_atoms)
        reactive = generate_fragment_data(core_mol)

        synthon_mol = extract_synthons_from_product(mol, core_atoms)
        non_reactive = generate_fragment_data(synthon_mol)

        product_fragments.append((reactive, non_reactive))

    return reactant_fragments, product_fragments
Example #3
0
def get_non_reaction_core_atoms(rsmiles, cores):
    """ Computes, for every molecule of the reaction, the complement of its reaction core.

        :param rsmiles: Reaction SMILES string.
        :param cores: ``(reactant_cores, product_cores)`` pair of per-molecule atom-index collections.
        :return: ``(reactant_lists, product_lists)`` tuple; each holds one set per molecule with the indices of the
                 atoms that are NOT in that molecule's core. """

    reactants, _, products = parse_reaction_roles(rsmiles,
                                                  as_what="mol_no_maps")
    reverse_cores = ([], [])

    for side, molecules in enumerate((reactants, products)):
        for position, mol in enumerate(molecules):
            reverse_cores[side].append({
                atom.GetIdx()
                for atom in mol.GetAtoms()
                if atom.GetIdx() not in cores[side][position]
            })

    return reverse_cores
Example #4
0
def get_separated_cores(rsmiles, cores):
    """ Returns the separated cores among the core atoms marked by the mapping.

        For every molecule, the core atoms that are directly bonded to each other (in the mapped molecule) are
        collected as atom pairs and merged with ``merge_common`` into connected groups. Every core atom that
        ``atom_in_core`` does not find in any of those groups is appended as its own single-atom group.

        :param rsmiles: Mapped reaction SMILES string.
        :param cores: ``(reactant_cores, product_cores)`` pair of per-molecule atom-index collections, e.g. the output
                      of ``get_reaction_core_atoms`` for the same reaction.
        :return: ``[reactant_groups, product_groups]``; each is a list with, per molecule, the list of separated core
                 groups. """
    from itertools import combinations

    reactants, _, products = parse_reaction_roles(rsmiles, as_what="mol")
    roles = [reactants, products]
    separated_cores = [[], []]

    for role_ind, role_cores in enumerate(cores):
        for mol_ind, core_atoms in enumerate(role_cores):
            # Collect every bonded pair of core atoms exactly once. combinations() yields the pairs in the same order
            # as a nested "i before j" scan, and the seen-set replaces the quadratic "pair not in list" checks.
            connections, seen_pairs = [], set()
            for atom1, atom2 in combinations(core_atoms, 2):
                pair_key = frozenset((atom1, atom2))
                if pair_key in seen_pairs:
                    continue
                if roles[role_ind][mol_ind].GetBondBetweenAtoms(atom1, atom2) is not None:
                    seen_pairs.add(pair_key)
                    connections.append([atom1, atom2])

            # Merge the bonded pairs into connected groups.
            components = list(merge_common(connections))

            # Core atoms that are not part of any group form single-atom cores. The check runs against the merged
            # groups only, before any singleton is added.
            isolated = [[atom] for atom in core_atoms if not atom_in_core(atom, components)]

            separated_cores[role_ind].append(components + isolated)

    return separated_cores
def _evaluation_bond_fp(product, from_atoms, input_config):
    """ Builds one bond descriptor for ``create_final_evaluation_dataset``.

        Uses ``construct_ecfp`` (returned as a float NumPy array) when ``input_config["type"]`` is ``"ecfp"`` and
        ``construct_hsfp`` (with ``input_config["ext"]`` as neighbourhood extension) for any other type.

        :param product: RDKit Mol object the fingerprint is generated from.
        :param from_atoms: Atom indices that the fingerprint is centred on.
        :param input_config: Descriptor configuration with the keys "type", "radius", "bits" and, for HSFP, "ext".
        :return: The result of the selected fingerprint constructor. """
    if input_config["type"] == "ecfp":
        return construct_ecfp(product,
                              radius=input_config["radius"],
                              bits=input_config["bits"],
                              from_atoms=from_atoms,
                              output_type="np_array",
                              as_type="np_float")

    return construct_hsfp(product,
                          radius=input_config["radius"],
                          bits=input_config["bits"],
                          from_atoms=from_atoms,
                          neighbourhood_ext=input_config["ext"])


def create_final_evaluation_dataset(args):
    """ Creates a version of the test dataset where the non-reactive substructures are not filtered out and the
        compounds are treated like real unknown input compounds without mapping or known reaction class.

        One row is produced per bond of every product molecule of every test reaction. A bond is flagged as
        ``in_core`` when at least one of its atoms is in the product's reaction core. Only core bonds keep the reaction
        class; all other bonds are labelled 0. The result is stored as a .pkl file at
        ``args.evaluation_config.final_evaluation_dataset``.

        :param args: Configuration object providing ``dataset_config``, ``evaluation_config`` and
                     ``descriptor_config``. """

    # Configuration values used for every bond, looked up once.
    input_config = args.evaluation_config.best_input_config
    similarity_config = args.descriptor_config.similarity_search

    # Read the test dataset from the specified fold.
    test_dataset = pd.read_pickle(
        args.dataset_config.output_folder +
        "fold_{}/test_data.pkl".format(args.evaluation_config.best_fold))
    final_data_tuples = []

    for row_ind, row in tqdm(
            test_dataset.iterrows(),
            total=len(test_dataset.index),
            ascii=True,
            desc="Generating the non-filtered version of the test dataset"):
        # Only the products are needed; the reaction cores are used to label the bonds.
        _, _, products = parse_reaction_roles(row["reaction_smiles"],
                                              as_what="mol_no_maps")
        products_reaction_cores = get_reaction_core_atoms(
            row["reaction_smiles"])[1]

        for p_ind, product in enumerate(products):
            core_atoms = products_reaction_cores[p_ind]

            for bond in product.GetBonds():
                # The bond atoms and their extended neighbourhood.
                bond_atoms = {bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()}
                ext_bond_atoms = get_atom_environment(bond_atoms, product)

                bond_fp = _evaluation_bond_fp(product, bond_atoms, input_config)
                ext_bond_fp = _evaluation_bond_fp(product, ext_bond_atoms, input_config)

                # A bond touching the reaction core is considered part of it.
                in_core = bond.GetBeginAtomIdx() in core_atoms or bond.GetEndAtomIdx() in core_atoms

                # Only the non-reactive parts are stored in the dataset; the reactive parts are discarded.
                _, non_reactive_part = extract_info_from_molecule(
                    product, bond_atoms)
                _, ext_non_reactive_part = extract_info_from_molecule(
                    product, ext_bond_atoms)

                non_reactive_fps = [
                    construct_ecfp(nrp_mol,
                                   radius=similarity_config["radius"],
                                   bits=similarity_config["bits"])
                    for nrp_mol in non_reactive_part[2]
                ]
                ext_non_reactive_fps = [
                    construct_ecfp(nrp_mol,
                                   radius=similarity_config["radius"],
                                   bits=similarity_config["bits"])
                    for nrp_mol in ext_non_reactive_part[2]
                ]

                final_data_tuples.append((
                    row["patent_id"] + "_{}".format(row_ind),
                    bond.GetIdx(),
                    bond_atoms,
                    bond_fp,
                    ext_bond_atoms,
                    ext_bond_fp,
                    in_core,
                    products_reaction_cores,
                    non_reactive_part[0],
                    non_reactive_part[2],
                    non_reactive_part[3],
                    non_reactive_fps,
                    ext_non_reactive_part[0],
                    ext_non_reactive_part[2],
                    ext_non_reactive_part[3],
                    ext_non_reactive_fps,
                    row["reaction_smiles"],
                    row["reaction_class"] if in_core else 0,
                    row["reactants_uq_mol_maps"]))

    # Save the final evaluation dataset as a .pkl file.
    pd.DataFrame(final_data_tuples, columns=["patent_id", "bond_id", "bond_atoms", "bond_fp", "ext_bond_atoms", "ext_bond_fp", "in_core", "reaction_cores",
                                             "non_reactive_smiles", "non_reactive_smols", "non_reactive_smals", "non_reactive_fps",
                                             "ext_non_reactive_smiles", "ext_non_reactive_smols", "ext_non_reactive_smals", "ext_non_reactive_fps",
                                             "reaction_smiles", "reaction_class", "reactants_uq_mol_maps"])\
        .to_pickle(args.evaluation_config.final_evaluation_dataset)
def _substructure_fp(mol, from_atoms, fp_config):
    """ Builds a fingerprint of ``mol`` centred on ``from_atoms`` for ``generate_fps_from_reaction_products``.

        Uses ``construct_ecfp`` (returned as a float NumPy array) when ``fp_config["type"]`` is ``"ecfp"`` and
        ``construct_hsfp`` (with ``fp_config["ext"]`` as neighbourhood extension) for any other type.

        :param mol: RDKit Mol object.
        :param from_atoms: Atom indices that the fingerprint is centred on.
        :param fp_config: Dictionary with the keys "type", "radius", "bits" and, for HSFP, "ext".
        :return: The result of the selected fingerprint constructor. """
    if fp_config["type"] == "ecfp":
        return construct_ecfp(mol,
                              radius=fp_config["radius"],
                              bits=fp_config["bits"],
                              from_atoms=from_atoms,
                              output_type="np_array",
                              as_type="np_float")

    return construct_hsfp(mol,
                          radius=fp_config["radius"],
                          bits=fp_config["bits"],
                          from_atoms=from_atoms,
                          neighbourhood_ext=fp_config["ext"])


def generate_fps_from_reaction_products(reaction_smiles, fp_data_configs):
    """ Generates the specified fingerprints for the reactive and non-reactive substructures of the product molecules
        of a chemical reaction.

        Reactive fingerprints are built from every separated reaction core of a product. Non-reactive fingerprints are
        built from every product bond whose degree-1 environment does not overlap the degree-1 environment of the
        product's reaction core.

        :param reaction_smiles: Mapped reaction SMILES string.
        :param fp_data_configs: Iterable of fingerprint configurations (see ``_substructure_fp``).
        :return: ``(total_reactive_fps, total_non_reactive_fps)``; each holds one list of fingerprints per
                 (product, configuration) combination, products in the outer order. """

    _, _, products = parse_reaction_roles(reaction_smiles,
                                          as_what="mol_no_maps")
    reaction_cores = get_reaction_core_atoms(reaction_smiles)

    # Separate the reaction cores if they consist of multiple non-neighbouring parts.
    separated_cores = get_separated_cores(reaction_smiles, reaction_cores)

    total_reactive_fps, total_non_reactive_fps = [], []

    for p_ind, product in enumerate(products):
        # The extended core environment and the bonds lying outside it depend only on the product, not on the
        # fingerprint configuration, so compute them once per product.
        extended_core_env = get_atom_environment(reaction_cores[1][p_ind],
                                                 product,
                                                 degree=1)
        non_reactive_bonds = [
            bond for bond in product.GetBonds()
            if not get_bond_environment(bond, product, degree=1).intersection(extended_core_env)
        ]

        for fp_config in fp_data_configs:
            # Fingerprints of the reactive substructures, i.e. the reaction core(s).
            total_reactive_fps.append([
                _substructure_fp(product, core, fp_config)
                for core in separated_cores[1][p_ind]
            ])
            # Fingerprints of the bonds that are far enough from the reaction core.
            total_non_reactive_fps.append([
                _substructure_fp(product, [bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()], fp_config)
                for bond in non_reactive_bonds
            ])

    return total_reactive_fps, total_non_reactive_fps
def _aggregate_reaction_classes(pool_smiles, pool_classes, desc):
    """ Replaces every per-occurrence reaction class with the set of all classes seen for the same compound.

        :param pool_smiles: Canonical SMILES of every compound occurrence (duplicates allowed).
        :param pool_classes: Reaction class of each occurrence, aligned with ``pool_smiles``.
        :param desc: Progress-bar description.
        :return: List aligned with ``pool_smiles``; occurrences of the same SMILES share one set object. """
    classes_by_smiles = {}

    # One dictionary pass replaces rescanning the whole pool for every compound.
    for smiles, reaction_class in tqdm(zip(pool_smiles, pool_classes), total=len(pool_smiles), desc=desc):
        classes_by_smiles.setdefault(smiles, set()).add(reaction_class)

    return [classes_by_smiles[smiles] for smiles in pool_smiles]


def _save_compound_pool(args, pool_smiles, pool_mol, pool_classes, compound_label, file_name):
    """ Generates the similarity-search ECFPs of a unique compound pool and stores the pool as a .pkl file.

        :param args: Configuration object providing ``descriptor_config.similarity_search`` and
                     ``dataset_config.output_folder``.
        :param pool_smiles: Unique canonical SMILES.
        :param pool_mol: RDKit Mol objects aligned with ``pool_smiles``.
        :param pool_classes: Sets of reaction classes aligned with ``pool_smiles``.
        :param compound_label: "reactant" or "product"; used in the progress messages.
        :param file_name: Name of the output file inside the output folder. """
    similarity_config = args.descriptor_config.similarity_search

    # NOTE(review): the column is named "ecfp_1024" regardless of the configured number of bits.
    ecfp_1024 = [
        construct_ecfp(smiles,
                       radius=similarity_config["radius"],
                       bits=similarity_config["bits"])
        for smiles in tqdm(pool_smiles,
                           total=len(pool_smiles),
                           desc="Generating {} compound fingerprints".format(compound_label))
    ]

    print("Saving the processed {} compound data...".format(compound_label), end="")

    pd.DataFrame({"mol_id": list(range(0, len(pool_smiles))), "canonical_smiles": pool_smiles,
                  "mol_object": pool_mol, "ecfp_1024": ecfp_1024, "reaction_class": pool_classes}).\
        to_pickle(args.dataset_config.output_folder + file_name)

    print("done.")


def generate_unique_compound_pools(args):
    """ Generates and stores unique (RDKit Canonical SMILES) chemical compound pools of the reactants and products for a
        chemical reaction dataset.

        The raw dataset needs a 'rxn_smiles' column with the mapped reaction SMILES strings and a 'class' column with
        the reaction class. Reagents are skipped. Each unique compound is stored with its RDKit Mol object, its
        similarity-search ECFP and the set of all reaction classes it occurs in. The outputs are written to
        "unique_reactants_pool.pkl" and "unique_products_pool.pkl" in the output folder.

        :param args: Configuration object providing ``dataset_config`` and ``descriptor_config``. """

    reactant_pool_smiles, product_pool_smiles, reactant_pool_mol, product_pool_mol = [], [], [], []
    reactant_reaction_class, product_reaction_class = [], []

    raw_dataset = pd.read_csv(args.dataset_config.raw_dataset)

    # Collect every reactant and product occurrence together with the reaction class of its entry.
    for _, row in tqdm(
            raw_dataset.iterrows(),
            total=len(raw_dataset.index),
            desc=
            "Generating unique reactant and product compound representations"):
        reactants, _, products = parse_reaction_roles(
            row["rxn_smiles"], as_what="canonical_smiles_no_maps")
        reactant_pool_smiles.extend(reactants)
        product_pool_smiles.extend(products)

        reactants, _, products = parse_reaction_roles(row["rxn_smiles"],
                                                      as_what="mol_no_maps")
        reactant_pool_mol.extend(reactants)
        product_pool_mol.extend(products)

        reactant_reaction_class.extend(row["class"] for _ in reactants)
        product_reaction_class.extend(row["class"] for _ in products)

    # Aggregate the reaction classes of identical compounds.
    reactant_reaction_class = _aggregate_reaction_classes(
        reactant_pool_smiles, reactant_reaction_class,
        "Aggregating reaction class values for the reactant compounds")
    product_reaction_class = _aggregate_reaction_classes(
        product_pool_smiles, product_reaction_class,
        "Aggregating reaction class values for the product compounds")

    print("Filtering unique reactant and product compounds...", end="")

    # Keep the first occurrence of every unique SMILES; np.unique also sorts the SMILES.
    reactant_pool_smiles, reactants_uq_ind = np.unique(reactant_pool_smiles,
                                                       return_index=True)
    product_pool_smiles, products_uq_ind = np.unique(product_pool_smiles,
                                                     return_index=True)

    # Apply the same selection to the aligned Mol objects and reaction classes.
    reactant_pool_mol = [reactant_pool_mol[ind] for ind in reactants_uq_ind]
    product_pool_mol = [product_pool_mol[ind] for ind in products_uq_ind]
    reactant_reaction_class = [reactant_reaction_class[ind] for ind in reactants_uq_ind]
    product_reaction_class = [product_reaction_class[ind] for ind in products_uq_ind]

    print("done.")

    _save_compound_pool(args, reactant_pool_smiles, reactant_pool_mol, reactant_reaction_class,
                        "reactant", "unique_reactants_pool.pkl")
    _save_compound_pool(args, product_pool_smiles, product_pool_mol, product_reaction_class,
                        "product", "unique_products_pool.pkl")
def extract_relevant_information(reaction_smiles, uq_reactant_mols_pool,
                                 uq_product_mols_pool, fp_params):
    """ Extracts the necessary information from a single mapped reaction SMILES string.

        Reactants and products are visited in descending order of their number of atoms (ties keep their original
        order), so the largest molecule of each side always comes first.

        :param reaction_smiles: Mapped reaction SMILES string.
        :param uq_reactant_mols_pool: Sequence of unique reactant canonical SMILES; each reactant is reported by its
                                      position in it (via ``.index()``).
        :param uq_product_mols_pool: Same as above for the products.
        :param fp_params: Dictionary with the "radius" and "bits" used for every ECFP.
        :return: 18-tuple. First the reactant data: pool indices, then the reactive-part items
                 (fragment data entries [0], [2], [3] and ECFPs), then the same four non-reactive-part items. The
                 same nine items then follow for the products. Reactant lists hold one entry per reactant; product
                 lists are flattened over the fragments of all products. """

    # Canonical SMILES and RDKit Mol objects of the same molecules, in the same order.
    reactant_smiles, _, product_smiles = parse_reaction_roles(
        reaction_smiles, as_what="canonical_smiles_no_maps")
    reactants, _, products = parse_reaction_roles(reaction_smiles,
                                                  as_what="mol_no_maps")

    # Sort molecule POSITIONS instead of the molecules themselves. extract_info_from_reaction() returns the fragments
    # in the original molecule order, so the same permutation must be used to look them up. Previously only the
    # molecules were sorted, which paired fragments with the wrong molecule. sorted() is stable, so ties keep their
    # original relative order.
    reactant_order = sorted(range(len(reactants)),
                            key=lambda ind: len(reactants[ind].GetAtoms()),
                            reverse=True)
    product_order = sorted(range(len(products)),
                           key=lambda ind: len(products[ind].GetAtoms()),
                           reverse=True)

    r_uq_mol_maps, rr_smiles, rr_smols, rr_smals, rr_fps, rnr_smiles, rnr_smols, rnr_smals, rnr_fps = \
        [], [], [], [], [], [], [], [], []
    p_uq_mol_maps, pr_smiles, pr_smols, pr_smals, pr_fps, pnr_smiles, pnr_smols, pnr_smals, pnr_fps = \
        [], [], [], [], [], [], [], [], []

    # Reactive and non-reactive parts of every reactant and product, in the original molecule order.
    reactant_frags, product_frags = extract_info_from_reaction(reaction_smiles)

    for r_ind in reactant_order:
        reactive_part, non_reactive_part = reactant_frags[r_ind]

        r_uq_mol_maps.append(
            uq_reactant_mols_pool.index(reactant_smiles[r_ind]))
        rr_smiles.append(reactive_part[0])
        rnr_smiles.append(non_reactive_part[0])
        rr_smols.append(reactive_part[2])
        rnr_smols.append(non_reactive_part[2])
        rr_smals.append(reactive_part[3])
        rnr_smals.append(non_reactive_part[3])
        rr_fps.append(
            construct_ecfp(reactive_part[2],
                           radius=fp_params["radius"],
                           bits=fp_params["bits"]))
        rnr_fps.append(
            construct_ecfp(non_reactive_part[2],
                           radius=fp_params["radius"],
                           bits=fp_params["bits"]))

    for p_ind in product_order:
        reactive_part, non_reactive_part = product_frags[p_ind]

        # Product fragment data holds lists (one entry per fragment), which are flattened into the outputs.
        p_uq_mol_maps.append(uq_product_mols_pool.index(product_smiles[p_ind]))
        pr_smiles.extend(reactive_part[0])
        pnr_smiles.extend(non_reactive_part[0])
        pr_smols.extend(reactive_part[2])
        pnr_smols.extend(non_reactive_part[2])
        pr_smals.extend(reactive_part[3])
        pnr_smals.extend(non_reactive_part[3])

        pr_fps.extend(
            construct_ecfp(pf, radius=fp_params["radius"], bits=fp_params["bits"])
            for pf in reactive_part[2])
        pnr_fps.extend(
            construct_ecfp(pf, radius=fp_params["radius"], bits=fp_params["bits"])
            for pf in non_reactive_part[2])

    return r_uq_mol_maps, rr_smiles, rr_smols, rr_smals, rr_fps, rnr_smiles, rnr_smols, rnr_smals, rnr_fps,\
           p_uq_mol_maps, pr_smiles, pr_smols, pr_smals, pr_fps, pnr_smiles, pnr_smols, pnr_smals, pnr_fps
def draw_reaction(rxn,
                  show_reagents=True,
                  reaction_cores=None,
                  im_size_x=300,
                  im_size_y=200):
    """ Renders a chemical reaction as one horizontal RGB image, optionally highlighting the reaction cores.

        Layout: ``reactant + reactant -> [reagent + reagent ->] product + product``, where the separators are read
        from "assets/plus.png" and "assets/arrow.png".

        :param rxn: Reaction SMILES string (parsed with ``parse_reaction_roles``) or an object exposing
                    ``GetReactants()`` / ``GetProducts()``. Reagents are only available for SMILES input.
        :param show_reagents: Whether reagent molecules (if any) are drawn.
        :param reaction_cores: Optional ``(reactant_cores, product_cores)`` pair of per-molecule atom indices to
                               highlight; a side given as an empty list is drawn without highlighting.
        :param im_size_x: Width passed to ``draw_molecule`` for every molecule.
        :param im_size_y: Height passed to ``draw_molecule``; also used to vertically centre the separators.
        :return: The composed PIL image. """
    if reaction_cores is None:
        reaction_cores = [[], []]

    if isinstance(rxn, str):
        reactants, reagents, products = parse_reaction_roles(rxn,
                                                             as_what="mol")
    else:
        reactants, reagents, products = rxn.GetReactants(), [], rxn.GetProducts()

    tiles = []

    # Reactants: each is followed by '+', except the last one, which is followed by '->'.
    for position, reactant in enumerate(reactants):
        if len(reaction_cores[0]) > 0:
            picture = draw_molecule(reactant,
                                    im_size_x,
                                    im_size_y,
                                    highlight_atoms=[reaction_cores[0][position]])
        else:
            picture = draw_molecule(reactant, im_size_x, im_size_y)
        tiles.append(picture)

        is_last_reactant = position == len(reactants) - 1
        tiles.append(Image.open("assets/arrow.png" if is_last_reactant else "assets/plus.png"))

    # Reagents (optional) use the same '+' / '->' pattern as the reactants.
    if show_reagents and len(reagents) > 0:
        for position, reagent in enumerate(reagents):
            tiles.append(draw_molecule(reagent, im_size_x, im_size_y))
            is_last_reagent = position == len(reagents) - 1
            tiles.append(Image.open("assets/arrow.png" if is_last_reagent else "assets/plus.png"))

    # Products: separated by '+', with nothing after the last one.
    for position, product in enumerate(products):
        if len(reaction_cores[1]) > 0:
            highlight = [reaction_cores[1][position]]
        else:
            highlight = []
        tiles.append(draw_molecule(product, im_size_x, im_size_y, highlight_atoms=highlight))

        if position < len(products) - 1:
            tiles.append(Image.open("assets/plus.png"))

    # The canvas is as wide as all tiles together and as tall as the tallest tile.
    sizes = [tile.size for tile in tiles]
    canvas = Image.new("RGB",
                       (sum(width for width, _ in sizes), max(height for _, height in sizes)),
                       (255, 255, 255))

    # Molecules and separators alternate, so every odd tile is a separator and is centred vertically against the
    # molecule height; molecules are aligned to the top.
    x_cursor = 0
    for position, tile in enumerate(tiles):
        y_cursor = round(im_size_y / 2 - tile.size[1] / 2) if position % 2 else 0
        canvas.paste(tile, (x_cursor, y_cursor))
        x_cursor += tile.size[0]

    return canvas