Ejemplo n.º 1
0
def _bonding_spheres(ligand, index):
    """Group the atoms of ``ligand`` by bond distance from atom ``index``.

    Returns a list of lists: entry 0 is ``[index]`` and entry k holds the
    indices of atoms first reached after k bonds (breadth-first over
    ``ligand.getBondedAtoms``). The walk stops when a step discovers no new
    atoms. Order within each entry is unspecified.
    """
    spheres = [[index]]
    seen = {index}
    while True:
        frontier = set()
        for atom_index in spheres[-1]:
            frontier.update(ligand.getBondedAtoms(atom_index))
        # Only atoms not already placed in a closer sphere belong to this one.
        frontier -= seen
        if not frontier:
            return spheres
        seen |= frontier
        spheres.append(list(frontier))


def checksymmetric(lig):
    """Check whether a ligand's bonding atoms have equivalent environments.

    Each connecting (bonding) atom is described by its "bonding environment":
    a list of coordination spheres, where sphere k holds the atoms exactly k
    bonds away from the bonding atom (sphere 0 is the bonding atom itself).
    Atom indices are replaced by atom names and each sphere is sorted, so two
    environments compare equal when they contain the same atom names at every
    bond distance.

    Parameters
    ----------
        lig : str
            Ligand identifier; passed to ``lig_load`` and used as a key into
            the dictionary returned by ``getlicores()``, whose entry ``[2]``
            is treated as the list of connecting-atom indices.

    Returns
    -------
        symmetric : bool
            True if the first two connecting atoms have identical
            environments. Ligands with fewer than two connecting atoms are
            trivially symmetric and return True.
    """
    ligand, _ = lig_load(lig)
    ligand.convert2mol3D()
    connecting_atoms = getlicores()[lig][2]

    # Nothing to compare against: avoid indexing a missing second environment.
    if len(connecting_atoms) < 2:
        return True

    environments = []
    for atom in connecting_atoms:
        spheres = _bonding_spheres(ligand, int(atom))
        # Sorting the names makes each sphere independent of set ordering.
        environments.append(
            [sorted(ligand.getAtom(int(a)).name for a in sphere)
             for sphere in spheres])

    # NOTE: only the first two environments are compared, matching the
    # original behavior for ligands with more than two connecting atoms.
    return environments[0] == environments[1]
Ejemplo n.º 2
0
def findshape(args, master_ligand):
    """Determine the relative positioning of different ligating atoms.

    The core geometry named by ``args.geometry`` is loaded; its slot 0 is used
    as the metal position. The slots reserved for ``master_ligand`` are found
    by walking ``args.lig`` in order and skipping ``ligocc * denticity`` slots
    for every ligand that precedes it. For every ordered pair of the master
    ligand's slots, the angle at the metal is computed with
    ``inverseCosRule``.

    Parameters
    ----------
        args : Namespace
            Namespace argument from inparse. Reads ``geometry``, ``lig``,
            ``ligocc`` and ``smicat``.
        master_ligand : mol3D
            mol3D class instance of metal with the ligand; only its ``name``
            is used, to locate it within ``args.lig``.

    Returns
    -------
        angles_dict : dict
            A dictionary of angles (in degrees) between catoms, keyed by
            ``'i-j'`` where i and j index the master ligand's connecting
            atoms.

    Raises
    ------
        ValueError
            If no ligand in ``args.lig`` has the same name as
            ``master_ligand``.
    """
    core = loadcoord(args.geometry)

    # Load ligands and identify the denticity of each. SMILES ligands take
    # their denticity from their own entry in args.smicat (one entry per
    # SMILES ligand, in order of appearance), so the index must advance.
    ligands = [lig_load(name)[0] for name in args.lig]
    smiles_index = 0
    for lig in ligands:
        if lig.ident == 'smi':
            lig.denticity = len(args.smicat[smiles_index])
            smiles_index += 1

    # Slot 0 of the geometry is the metal, so ligand slots start at 1.
    bind = 1
    master_denticity = None
    for counter, lig in enumerate(ligands):
        if lig.name == master_ligand.name:
            master_denticity = lig.denticity
            break
        bind += int(args.ligocc[counter]) * int(lig.denticity)
    if master_denticity is None:
        raise ValueError('findshape: ligand ' + str(master_ligand.name) +
                         ' not found among the requested ligands')
    binding_locations = range(bind, bind + master_denticity)

    metal_coords = np.array(core[0])
    ligating_coords = [np.array(core[i]) for i in binding_locations]

    angles_dict = dict()
    for i, coords_i in enumerate(ligating_coords):
        for j, coords_j in enumerate(ligating_coords):
            angles_dict[str(i) + '-' + str(j)] = inverseCosRule(
                coords_i, metal_coords, coords_j)
    return angles_dict
Ejemplo n.º 3
0
def draw_supervisor(args, rundir):
    """Draw an SVG of a single requested structure.

    Exactly one structure is drawn per call, chosen by priority: the first
    ligand in ``args.lig``; otherwise the first core in ``args.core``;
    otherwise the first substrate in ``args.substrate``. Informational
    messages are printed when extra structures are ignored, and a notice is
    printed when nothing was requested.

    Parameters
    ----------
        args : Namespace
            Namespace argument from inparse. Reads ``lig``, ``core`` and
            ``substrate``.
        rundir : str
            Accepted for interface compatibility; not used here.
    """
    if args.lig:
        print('Due to technical limitations, we will draw only the first ligand.')
        print('To view multiple ligands at once, consider using the GUI instead.')
        ligand_name = args.lig[0]
        ligand_mol, _ = lig_load(ligand_name)
        ligand_mol.draw_svg(ligand_name)
        return

    if args.core:
        if len(args.core) > 1:
            print('Due to technical limitations, we will draw only the first core.')
        print('Drawing the core.')
        if args.substrate:
            # The substrate is ignored on this run; only the core is drawn.
            print('Due to technical limitations, we can draw only one structure per run. To draw the substrate, run the program again.')
        core_name = args.core[0]
        core_mol, _ = core_load(core_name)
        core_mol.draw_svg(core_name)
        return

    if not args.substrate:
        print('You have not specified anything to draw. Currently supported: ligand, core, substrate')
        return

    if len(args.substrate) > 1:
        print('Due to technical limitations, we will draw only the first substrate.')
    print('Drawing the substrate.')
    substrate_name = args.substrate[0]
    print(substrate_name)
    substrate_mol, _ = substr_load(substrate_name)
    substrate_mol.draw_svg(substrate_name)
Ejemplo n.º 4
0
def tf_ANN_preproc(args, ligs, occs, dents, batslist, tcats, licores):
    # prepares and runs ANN calculation

    current_time = time.time()
    start_time = current_time
    last_time = current_time

    ######################
    ANN_reason = {}
    ANN_attributes = {}
    ######################

    r = 0
    emsg = list()
    valid = True
    catalysis = False
    metal = args.core
    this_metal = metal.lower()
    if len(this_metal) > 2:
        this_metal = this_metal[0:2]
    newligs = []
    newcats = []
    newdents = []
    newoccs = []
    newdecs = [False] * 6
    newdec_inds = [[]] * 6
    ANN_trust = False
    count = -1
    for i, lig in enumerate(ligs):
        this_occ = occs[i]
        if args.debug:
            print(('working on lig: ' + str(lig)))
            print(('occ is  ' + str(this_occ)))
        for j in range(0, int(this_occ)):
            count += 1
            newligs.append(lig)
            newdents.append(dents[i])
            newcats.append(tcats[i])
            newoccs.append(1)
            if args.decoration:
                newdecs[count] = (args.decoration[i])
                newdec_inds[count] = (args.decoration_index[i])

    ligs = newligs
    dents = newdents
    tcats = newcats
    occs = newoccs
    if args.debug:
        print('tf_nn has finisihed prepping ligands')

    if not args.geometry == "oct":
        emsg.append(
            "[ANN] Geometry is not supported at this time, MUST give -geometry = oct"
        )
        valid = False
        ANN_reason = 'geometry not oct'
    if not args.oxstate:
        emsg.append("\n oxidation state must be given")
        valid = False
        ANN_reason = 'oxstate not given'
    if valid:
        oxidation_state = args.oxstate
        valid, oxidation_state = check_metal(this_metal, oxidation_state)
        if int(oxidation_state) in [3, 4, 5]:
            catalytic_moieties = ['oxo', 'x', 'hydroxyl', '[O--]', '[OH-]']
            if args.debug:
                print(('the ligands are', ligs))
                print((set(ligs).intersection(set(catalytic_moieties))))
            if len(set(ligs).intersection(set(catalytic_moieties))) > 0:
                catalysis = True
        # generate key in descriptor space
        ox = int(oxidation_state)
        spin = args.spin
        if args.debug:
            print(('metal is ' + str(this_metal)))
            print(('metal validity', valid))
    if not valid and not catalysis:
        emsg.append("\n Oxidation state not available for this metal")
        ANN_reason = 'ox state not available for metal'
    if valid:
        high_spin, spin_ops = spin_classify(this_metal, spin, ox)
    if not valid and not catalysis:
        emsg.append("\n this spin state not available for this metal")
        ANN_reason = 'spin state not available for metal'
    if emsg:
        print((str(" ".join(["ANN messages:"] + [str(i) for i in emsg]))))

    current_time = time.time()
    metal_check_time = current_time - last_time
    last_time = current_time
    if args.debug:
        print(('checking metal/ox took  ' +
               "{0:.2f}".format(metal_check_time) + ' seconds'))

    if valid or catalysis:
        (valid, axial_ligs, equitorial_ligs, ax_dent, eq_dent, ax_tcat,
         eq_tcat, axial_ind_list, equitorial_ind_list, ax_occs, eq_occs,
         pentadentate) = tf_check_ligands(ligs, batslist, dents, tcats, occs,
                                          args.debug)

        if args.debug:
            print(("ligand validity is  " + str(valid)))
            print(('Occs', occs))
            print(('Ligands', ligs))
            print(('Dents', dents))
            print(('Bats (backbone atoms)', batslist))
            print(('lig validity', valid))
            print(('ax ligs', axial_ligs))
            print(('eq ligs', equitorial_ligs))
            print(('spin is', spin))

        if catalysis:
            valid = False
    if (not valid) and (not catalysis):
        ANN_reason = 'found incorrect ligand symmetry'
    elif not valid and catalysis:
        if args.debug:
            print('tf_nn detects catalytic')
        ANN_reason = 'catalytic structure presented'

    # placeholder for metal
    metal_mol = mol3D()
    metal_mol.addAtom(atom3D(metal))

    net_lig_charge = 0
    if valid or catalysis:
        if args.debug:
            print('loading axial ligands')
        ax_ligands_list = list()
        eq_ligands_list = list()
        for ii, axl in enumerate(axial_ligs):
            ax_lig3D, r_emsg = lig_load(axl, licores)  # load ligand
            net_lig_charge += ax_lig3D.charge
            if r_emsg:
                emsg += r_emsg
            if ax_tcat:
                ax_lig3D.cat = ax_tcat
                if args.debug:
                    print(('custom ax connect atom given (0-ind) ' +
                           str(ax_tcat)))
            if pentadentate and len(ax_lig3D.cat) > 1:
                ax_lig3D.cat = [ax_lig3D.cat[-1]]
            this_lig = ligand(mol3D(), [], ax_dent)
            this_lig.mol = ax_lig3D

            # check decoration index
            if newdecs:
                if newdecs[axial_ind_list[ii]]:
                    print(('decorating ' + str(axl) + ' with ' +
                           str(newdecs[axial_ind_list[ii]]) + ' at sites ' +
                           str(newdec_inds[axial_ind_list[ii]])))
                    ax_lig3D = decorate_ligand(args, axl,
                                               newdecs[axial_ind_list[ii]],
                                               newdec_inds[axial_ind_list[ii]])
            ax_lig3D.convert2mol3D()  # mol3D representation of ligand
            for jj in range(0, ax_occs[ii]):
                ax_ligands_list.append(this_lig)
        print(('Obtained the net ligand charge, which is... ', net_lig_charge))
        if args.debug:
            print('ax_ligands_list:')
            print(ax_ligands_list)
            print([h.mol.cat for h in ax_ligands_list])

        if args.debug:
            print(('loading equitorial ligands ' + str(equitorial_ligs)))
        for ii, eql in enumerate(equitorial_ligs):
            eq_lig3D, r_emsg = lig_load(eql, licores)  # load ligand
            net_lig_charge += eq_lig3D.charge
            if r_emsg:
                emsg += r_emsg
            if eq_tcat:
                eq_lig3D.cat = eq_tcat
                if args.debug:
                    print(('custom eq connect atom given (0-ind) ' +
                           str(eq_tcat)))
            if pentadentate and len(eq_lig3D.cat) > 1:
                eq_lig3D.cat = eq_lig3D.cat[0:4]

            if newdecs:
                if args.debug:
                    print(('newdecs' + str(newdecs)))
                    print(
                        ('equitorial_ind_list is ' + str(equitorial_ind_list)))
                c = 0
                if newdecs[equitorial_ind_list[ii]]:
                    if args.debug:
                        print(('decorating ' + str(eql) + ' with ' +
                               str(newdecs[equitorial_ind_list[ii]]) +
                               ' at sites ' +
                               str(newdec_inds[equitorial_ind_list[ii]])))
                    eq_lig3D = decorate_ligand(
                        args, eql, newdecs[equitorial_ind_list[ii]],
                        newdec_inds[equitorial_ind_list[ii]])
                    c += 1

            eq_lig3D.convert2mol3D()  # mol3D representation of ligand
            this_lig = ligand(mol3D(), [], eq_dent)
            this_lig.mol = eq_lig3D

            for jj in range(0, eq_occs[ii]):
                eq_ligands_list.append(this_lig)
        if args.debug:
            print('eq_ligands_list:')
            print(eq_ligands_list)

            current_time = time.time()
            ligand_check_time = current_time - last_time
            last_time = current_time
            print(('checking ligs took ' +
                   "{0:.2f}".format(ligand_check_time) + ' seconds'))
            print(
                ('writing copies of ligands as used  in ANN to currrent dir : '
                 + os.getcwd()))
            for kk, l in enumerate(ax_ligands_list):
                l.mol.writexyz('axlig-' + str(kk) + '.xyz')
            for kk, l in enumerate(eq_ligands_list):
                l.mol.writexyz('eqlig-' + str(kk) + '.xyz')
        # make description of complex
        custom_ligand_dict = {
            "eq_ligand_list": eq_ligands_list,
            "ax_ligand_list": ax_ligands_list,
            "eq_con_int_list": [h.mol.cat for h in eq_ligands_list],
            "ax_con_int_list": [h.mol.cat for h in ax_ligands_list]
        }

        ox_modifier = {metal: ox}

        this_complex = assemble_connectivity_from_parts(
            metal_mol, custom_ligand_dict)

        if args.debug:
            print('custom_ligand_dict is : ')
            print(custom_ligand_dict)

    if args.debug:
        print(('finished checking ligands, valid is ' + str(valid)))
        print('assembling RAC custom ligand configuration dictionary')

    if valid:
        # =====Classifiers:=====
        _descriptor_names = ["oxstate", "spinmult", "charge_lig"]
        _descriptors = [ox, spin, net_lig_charge]
        descriptor_names, descriptors = get_descriptor_vector(
            this_complex, custom_ligand_dict, ox_modifier)
        descriptor_names = _descriptor_names + descriptor_names
        descriptors = _descriptors + descriptors
        flag_oct, geo_lse = ANN_supervisor("geo_static_clf",
                                           descriptors,
                                           descriptor_names,
                                           debug=args.debug)
        # Test for scikit-learn models
        # flag_oct, geo_lse = sklearn_supervisor("geo_static_clf", descriptors, descriptor_names, debug=False)
        sc_pred, sc_lse = ANN_supervisor("sc_static_clf",
                                         descriptors,
                                         descriptor_names,
                                         debug=args.debug)
        ANN_attributes.update({
            "geo_label": 0 if flag_oct[0, 0] <= 0.5 else 1,
            "geo_prob": flag_oct[0, 0],
            "geo_LSE": geo_lse[0],
            "geo_label_trust": lse_trust(geo_lse),
            "sc_label": 0 if sc_pred[0, 0] <= 0.5 else 1,
            "sc_prob": sc_pred[0, 0],
            "sc_LSE": sc_lse[0],
            "sc_label_trust": lse_trust(sc_lse)
        })

        # build RACs without geo
        con_mat = this_complex.graph
        descriptor_names, descriptors = get_descriptor_vector(
            this_complex, custom_ligand_dict, ox_modifier)

        # get one-hot-encoding (OHE)
        descriptor_names, descriptors = create_OHE(descriptor_names,
                                                   descriptors, metal,
                                                   oxidation_state)

        # get alpha
        alpha = 0.2  # default for B3LYP
        if args.exchange:
            try:
                if float(args.exchange) > 1:
                    alpha = float(args.exchange) / 100  # if given as %
                elif float(args.exchange) <= 1:
                    alpha = float(args.exchange)
            except:
                print('cannot cast exchange argument as a float, using 20%')
        descriptor_names += ['alpha']
        descriptors += [alpha]
        descriptor_names += ['ox']
        descriptors += [ox]
        descriptor_names += ['spin']
        descriptors += [spin]
        if args.debug:
            current_time = time.time()
            rac_check_time = current_time - last_time
            last_time = current_time
            print(('getting RACs took ' + "{0:.2f}".format(rac_check_time) +
                   ' seconds'))

        # get spin splitting:
        split, latent_split = ANN_supervisor('split', descriptors,
                                             descriptor_names, args.debug)
        if args.debug:
            current_time = time.time()
            split_ANN_time = current_time - last_time
            last_time = current_time
            print(('split ANN took ' + "{0:.2f}".format(split_ANN_time) +
                   ' seconds'))

        # get bond lengths:
        if oxidation_state == '2':
            r_ls, latent_r_ls = ANN_supervisor('ls_ii', descriptors,
                                               descriptor_names, args.debug)
            r_hs, latent_r_hs = ANN_supervisor('hs_ii', descriptors,
                                               descriptor_names, args.debug)
        elif oxidation_state == '3':
            r_ls, latent_r_ls = ANN_supervisor('ls_iii', descriptors,
                                               descriptor_names, args.debug)
            r_hs, latent_r_hs = ANN_supervisor('hs_iii', descriptors,
                                               descriptor_names, args.debug)
        if not high_spin:
            r = r_ls[0]
        else:
            r = r_hs[0]

        if args.debug:
            current_time = time.time()
            GEO_ANN_time = current_time - last_time
            last_time = current_time
            print(('GEO ANN took ' + "{0:.2f}".format(GEO_ANN_time) +
                   ' seconds'))

        h**o, latent_homo = ANN_supervisor('h**o', descriptors,
                                           descriptor_names, args.debug)
        if args.debug:
            current_time = time.time()
            homo_ANN_time = current_time - last_time
            last_time = current_time
            print(('h**o ANN took ' + "{0:.2f}".format(homo_ANN_time) +
                   ' seconds'))

        gap, latent_gap = ANN_supervisor('gap', descriptors, descriptor_names,
                                         args.debug)
        if args.debug:
            current_time = time.time()
            gap_ANN_time = current_time - last_time
            last_time = current_time
            print(('gap ANN took ' + "{0:.2f}".format(gap_ANN_time) +
                   ' seconds'))

        # get minimum distance to train (for splitting)

        split_dist = find_true_min_eu_dist("split", descriptors,
                                           descriptor_names)
        if args.debug:
            current_time = time.time()
            min_dist_time = current_time - last_time
            last_time = current_time
            print(('min dist took ' + "{0:.2f}".format(min_dist_time) +
                   ' seconds'))

        homo_dist = find_true_min_eu_dist("h**o", descriptors,
                                          descriptor_names)
        homo_dist = find_ANN_latent_dist("h**o", latent_homo, args.debug)
        if args.debug:
            current_time = time.time()
            min_dist_time = current_time - last_time
            last_time = current_time
            print(('min H**O dist took ' + "{0:.2f}".format(min_dist_time) +
                   ' seconds'))

        gap_dist = find_true_min_eu_dist("gap", descriptors, descriptor_names)
        gap_dist = find_ANN_latent_dist("gap", latent_gap, args.debug)
        if args.debug:
            current_time = time.time()
            min_dist_time = current_time - last_time
            last_time = current_time
            print(('min GAP dist took ' + "{0:.2f}".format(min_dist_time) +
                   ' seconds'))

        # save attributes for return
        ANN_attributes.update({'split': split[0][0]})
        ANN_attributes.update({'split_dist': split_dist})
        ANN_attributes.update({'This spin': spin})
        if split[0][0] < 0 and (abs(split[0]) > 5):
            ANN_attributes.update({'ANN_ground_state': spin_ops[1]})
        elif split[0][0] > 0 and (abs(split[0]) > 5):
            ANN_attributes.update({'ANN_ground_state': spin_ops[0]})
        else:
            ANN_attributes.update(
                {'ANN_ground_state': 'dgen ' + str(spin_ops)})

        ANN_attributes.update({'h**o': h**o[0][0]})
        ANN_attributes.update({'gap': gap[0][0]})
        ANN_attributes.update({'homo_dist': homo_dist})
        ANN_attributes.update({'gap_dist': gap_dist})

        # now that we have bond predictions, we need to map these
        # back to a length of equal size as the original ligand request
        # in order for molSimplify to understand if
        ANN_bondl = len(ligs) * [False]
        added = 0
        for ii, eql in enumerate(equitorial_ind_list):
            for jj in range(0, eq_occs[ii]):
                ANN_bondl[added] = r[2]
                added += 1

        for ii, axl in enumerate(axial_ind_list):
            if args.debug:
                print((ii, axl, added, ax_occs))
            for jj in range(0, ax_occs[ii]):
                if args.debug:
                    print((jj, axl, added, r[ii]))
                ANN_bondl[added] = r[ii]
                added += 1

        ANN_attributes.update({'ANN_bondl': 4 * [r[2]] + [r[0], r[1]]})

        HOMO_ANN_trust = 'not set'
        HOMO_ANN_trust_message = ""
        # Not quite sure if this should be divided by 3 or not, since RAC-155 descriptors
        if float(homo_dist) < 3:
            HOMO_ANN_trust_message = 'ANN results should be trustworthy for this complex '
            HOMO_ANN_trust = 'high'
        elif float(homo_dist) < 5:
            HOMO_ANN_trust_message = 'ANN results are probably useful for this complex '
            HOMO_ANN_trust = 'medium'
        elif float(homo_dist) <= 10:
            HOMO_ANN_trust_message = 'ANN results are fairly far from training data, be cautious '
            HOMO_ANN_trust = 'low'
        elif float(homo_dist) > 10:
            HOMO_ANN_trust_message = 'ANN results are too far from training data, be cautious '
            HOMO_ANN_trust = 'very low'
        ANN_attributes.update({'homo_trust': HOMO_ANN_trust})
        ANN_attributes.update({'gap_trust': HOMO_ANN_trust})

        ANN_trust = 'not set'
        ANN_trust_message = ""
        if float(split_dist / 3) < 0.25:
            ANN_trust_message = 'ANN results should be trustworthy for this complex '
            ANN_trust = 'high'
        elif float(split_dist / 3) < 0.75:
            ANN_trust_message = 'ANN results are probably useful for this complex '
            ANN_trust = 'medium'
        elif float(split_dist / 3) < 1.0:
            ANN_trust_message = 'ANN results are fairly far from training data, be cautious '
            ANN_trust = 'low'
        elif float(split_dist / 3) > 1.0:
            ANN_trust_message = 'ANN results are too far from training data, be cautious '
            ANN_trust = 'very low'
        ANN_attributes.update({'split_trust': ANN_trust})

        # print text to std out
        print(
            "******************************************************************"
        )
        print(
            "************** ANN is engaged and advising on spin ***************"
        )
        print(
            "************** and metal-ligand bond distances    ****************"
        )
        print(
            "******************************************************************"
        )
        if high_spin:
            print(('You have selected a high-spin state, s = ' + str(spin)))
        else:
            print(('You have selected a low-spin state, s = ' + str(spin)))
        # report to stdout
        if split[0] < 0 and not high_spin:
            if abs(split[0]) > 5:
                print(
                    'warning, ANN predicts a high spin ground state for this complex'
                )
            else:
                print(
                    'warning, ANN predicts a near degenerate ground state for this complex'
                )
        elif split[0] >= 0 and high_spin:
            if abs(split[0]) > 5:
                print(
                    'warning, ANN predicts a low spin ground state for this complex'
                )
            else:
                print(
                    'warning, ANN predicts a near degenerate ground state for this complex'
                )
        print(('delta is', split[0], ' spin is ', high_spin))
        print(("ANN predicts a spin splitting (HS - LS) of " +
               "{0:.2f}".format(float(split[0])) + ' kcal/mol at ' +
               "{0:.0f}".format(100 * alpha) + '% HFX'))
        print(('ANN low spin bond length (ax1/ax2/eq) is predicted to be: ' +
               " /".join(["{0:.2f}".format(float(i))
                          for i in r_ls[0]]) + ' angstrom'))
        print(('ANN high spin bond length (ax1/ax2/eq) is predicted to be: ' +
               " /".join(["{0:.2f}".format(float(i))
                          for i in r_hs[0]]) + ' angstrom'))
        print(('distance to splitting energy training data is ' +
               "{0:.2f}".format(split_dist)))
        print(ANN_trust_message)
        print(("ANN predicts a H**O value of " +
               "{0:.2f}".format(float(h**o[0])) + ' eV at ' +
               "{0:.0f}".format(100 * alpha) + '% HFX'))
        print(("ANN predicts a LUMO-H**O energetic gap value of " +
               "{0:.2f}".format(float(gap[0])) + ' eV at ' +
               "{0:.0f}".format(100 * alpha) + '% HFX'))
        print(HOMO_ANN_trust_message)
        print(('distance to H**O training data is ' +
               "{0:.2f}".format(homo_dist)))
        print(
            ('distance to GAP training data is ' + "{0:.2f}".format(gap_dist)))
        print(
            "*******************************************************************"
        )
        print(
            "************** ANN complete, saved in record file *****************"
        )
        print(
            "*******************************************************************"
        )
        from keras import backend as K
        # This is done to get rid of the attribute error that is a bug in tensorflow.
        K.clear_session()
        current_time = time.time()
        total_ANN_time = current_time - start_time
        last_time = current_time
        print(('Total ML functions took ' + "{0:.2f}".format(total_ANN_time) +
               ' seconds'))

    if catalysis:
        print('-----In Catalysis Mode-----')
        # build RACs without geo
        con_mat = this_complex.graph
        descriptor_names, descriptors = get_descriptor_vector(
            this_complex, custom_ligand_dict, ox_modifier)
        # get alpha
        alpha = 20  # default for B3LYP
        if args.exchange:
            try:
                if float(args.exchange) < 1:
                    alpha = float(args.exchange) * 100  # if given as %
                elif float(args.exchange) >= 1:
                    alpha = float(args.exchange)
            except:
                print('cannot case exchange argument as a float, using 20%')
        descriptor_names += ['alpha', 'ox', 'spin', 'charge_lig']
        descriptors += [alpha, ox, spin, net_lig_charge]
        if args.debug:
            current_time = time.time()
            rac_check_time = current_time - last_time
            last_time = current_time
            print(('getting RACs took ' + "{0:.2f}".format(rac_check_time) +
                   ' seconds'))
        oxo, latent_oxo = ANN_supervisor('oxo', descriptors, descriptor_names,
                                         args.debug)
        if args.debug:
            current_time = time.time()
            split_ANN_time = current_time - last_time
            last_time = current_time
        oxo_dist, avg_10_NN_dist, avg_traintrain = find_ANN_10_NN_normalized_latent_dist(
            "oxo", latent_oxo, args.debug)
        if args.debug:
            current_time = time.time()
            min_dist_time = current_time - last_time
            last_time = current_time
            print(('min oxo dist took ' + "{0:.2f}".format(min_dist_time) +
                   ' seconds'))
        ANN_attributes.update({'oxo': oxo[0][0]})
        ANN_attributes.update({'oxo_dist': oxo_dist})

        hat, latent_hat = ANN_supervisor('hat', descriptors, descriptor_names,
                                         args.debug)
        if args.debug:
            current_time = time.time()
            split_ANN_time = current_time - last_time
            last_time = current_time
            print(('HAT ANN took ' + "{0:.2f}".format(split_ANN_time) +
                   ' seconds'))

        hat_dist, avg_10_NN_dist, avg_traintrain = find_ANN_10_NN_normalized_latent_dist(
            "hat", latent_hat, args.debug)
        if args.debug:
            current_time = time.time()
            min_dist_time = current_time - last_time
            last_time = current_time
            print(('min hat dist took ' + "{0:.2f}".format(min_dist_time) +
                   ' seconds'))
        ANN_attributes.update({'hat': hat[0][0]})
        ANN_attributes.update({'hat_dist': hat_dist})

        ########## for Oxo and H**O optimization ##########
        oxo20, latent_oxo20 = ANN_supervisor('oxo20', descriptors,
                                             descriptor_names, args.debug)
        if args.debug:
            current_time = time.time()
            oxo20_ANN_time = current_time - last_time
            last_time = current_time
            print(('oxo20 ANN took ' + "{0:.2f}".format(oxo20_ANN_time) +
                   ' seconds'))
        # oxo20_dist = find_ANN_latent_dist("oxo20", latent_oxo20, args.debug)
        oxo20_dist, avg_10_NN_dist, avg_traintrain = find_ANN_10_NN_normalized_latent_dist(
            "oxo20", latent_oxo20, args.debug)
        if args.debug:
            current_time = time.time()
            min_dist_time = current_time - last_time
            last_time = current_time
            print(('min oxo20 dist took ' + "{0:.2f}".format(min_dist_time) +
                   ' seconds'))
        ANN_attributes.update({'oxo20': oxo20[0][0]})
        ANN_attributes.update({'oxo20_dist': oxo20_dist})
        # _ = find_ANN_latent_dist("oxo20", latent_oxo20, args.debug)
        # _ = find_true_min_eu_dist("oxo20", descriptors, descriptor_names, latent_space_vector=latent_oxo20)

        homo_empty, latent_homo_empty = ANN_supervisor('homo_empty',
                                                       descriptors,
                                                       descriptor_names,
                                                       args.debug)
        if args.debug:
            current_time = time.time()
            homo_empty_ANN_time = current_time - last_time
            last_time = current_time
            print(('homo_empty ANN took ' +
                   "{0:.2f}".format(homo_empty_ANN_time) + ' seconds'))
        # homo_empty_dist = find_ANN_latent_dist("homo_empty", latent_homo_empty, args.debug)
        homo_empty_dist, avg_10_NN_dist, avg_traintrain = find_ANN_10_NN_normalized_latent_dist(
            "homo_empty", latent_homo_empty, args.debug)
        if args.debug:
            current_time = time.time()
            min_dist_time = current_time - last_time
            last_time = current_time
            print(('min homo_empty dist took ' +
                   "{0:.2f}".format(min_dist_time) + ' seconds'))
        ANN_attributes.update({'homo_empty': homo_empty[0][0]})
        ANN_attributes.update({'homo_empty_dist': homo_empty_dist})
        # _ = find_ANN_latent_dist("homo_empty", latent_homo_empty, args.debug)
        # _ = find_true_min_eu_dist("homo_empty", descriptors, descriptor_names, latent_space_vector=latent_homo_empty)

        Oxo20_ANN_trust = 'not set'
        Oxo20_ANN_trust_message = ""
        # Not quite sure if this should be divided by 3 or not, since RAC-155 descriptors
        if float(oxo20_dist) < 0.75:
            Oxo20_ANN_trust_message = 'Oxo20 ANN results should be trustworthy for this complex '
            Oxo20_ANN_trust = 'high'
        elif float(oxo20_dist) < 1:
            Oxo20_ANN_trust_message = 'Oxo20 ANN results are probably useful for this complex '
            Oxo20_ANN_trust = 'medium'
        elif float(oxo20_dist) <= 1.25:
            Oxo20_ANN_trust_message = 'Oxo20 ANN results are fairly far from training data, be cautious '
            Oxo20_ANN_trust = 'low'
        elif float(oxo20_dist) > 1.25:
            Oxo20_ANN_trust_message = 'Oxo20 ANN results are too far from training data, be cautious '
            Oxo20_ANN_trust = 'very low'
        ANN_attributes.update({'oxo20_trust': Oxo20_ANN_trust})

        homo_empty_ANN_trust = 'not set'
        homo_empty_ANN_trust_message = ""
        # Not quite sure if this should be divided by 3 or not, since RAC-155 descriptors
        if float(homo_empty_dist) < 0.75:
            homo_empty_ANN_trust_message = 'homo_empty ANN results should be trustworthy for this complex '
            homo_empty_ANN_trust = 'high'
        elif float(homo_empty_dist) < 1:
            homo_empty_ANN_trust_message = 'homo_empty ANN results are probably useful for this complex '
            homo_empty_ANN_trust = 'medium'
        elif float(homo_empty_dist) <= 1.25:
            homo_empty_ANN_trust_message = 'homo_empty ANN results are fairly far from training data, be cautious '
            homo_empty_ANN_trust = 'low'
        elif float(homo_empty_dist) > 1.25:
            homo_empty_ANN_trust_message = 'homo_empty ANN results are too far from training data, be cautious '
            homo_empty_ANN_trust = 'very low'
        ANN_attributes.update({'homo_empty_trust': homo_empty_ANN_trust})

        ####################################################

        Oxo_ANN_trust = 'not set'
        Oxo_ANN_trust_message = ""
        # Not quite sure if this should be divided by 3 or not, since RAC-155 descriptors
        if float(oxo_dist) < 3:
            Oxo_ANN_trust_message = 'Oxo ANN results should be trustworthy for this complex '
            Oxo_ANN_trust = 'high'
        elif float(oxo_dist) < 5:
            Oxo_ANN_trust_message = 'Oxo ANN results are probably useful for this complex '
            Oxo_ANN_trust = 'medium'
        elif float(oxo_dist) <= 10:
            Oxo_ANN_trust_message = 'Oxo ANN results are fairly far from training data, be cautious '
            Oxo_ANN_trust = 'low'
        elif float(oxo_dist) > 10:
            Oxo_ANN_trust_message = 'Oxo ANN results are too far from training data, be cautious '
            Oxo_ANN_trust = 'very low'
        ANN_attributes.update({'oxo_trust': Oxo_ANN_trust})

        HAT_ANN_trust = 'not set'
        HAT_ANN_trust_message = ""
        # Not quite sure if this should be divided by 3 or not, since RAC-155 descriptors
        if float(hat_dist) < 3:
            HAT_ANN_trust_message = 'HAT ANN results should be trustworthy for this complex '
            HAT_ANN_trust = 'high'
        elif float(hat_dist) < 5:
            HAT_ANN_trust_message = 'HAT ANN results are probably useful for this complex '
            HAT_ANN_trust = 'medium'
        elif float(hat_dist) <= 10:
            HAT_ANN_trust_message = 'HAT ANN results are fairly far from training data, be cautious '
            HAT_ANN_trust = 'low'
        elif float(hat_dist) > 10:
            HAT_ANN_trust_message = 'HAT ANN results are too far from training data, be cautious '
            HAT_ANN_trust = 'very low'
        ANN_attributes.update({'hat_trust': HAT_ANN_trust})
        print(
            "*******************************************************************"
        )
        print(
            "**************       CATALYTIC ANN ACTIVATED!      ****************"
        )
        print(
            "*********** Currently advising on Oxo and HAT energies ************"
        )
        print(
            "*******************************************************************"
        )
        print(("ANN predicts a Oxo20 energy of " +
               "{0:.2f}".format(float(oxo20[0])) + ' kcal/mol at ' +
               "{0:.2f}".format(alpha) + '% HFX'))
        print(Oxo20_ANN_trust_message)
        print(('Distance to Oxo20 training data in the latent space is ' +
               "{0:.2f}".format(oxo20_dist)))
        print(("ANN predicts a empty site beta H**O level of " +
               "{0:.2f}".format(float(homo_empty[0])) + ' eV at ' +
               "{0:.2f}".format(alpha) + '% HFX'))
        print(homo_empty_ANN_trust_message)
        print((
            'Distance to empty site beta H**O level training data in the latent space is '
            + "{0:.2f}".format(homo_empty_dist)))
        print(
            '-------------------------------------------------------------------'
        )
        print(("ANN predicts a oxo formation energy of " +
               "{0:.2f}".format(float(oxo[0])) + ' kcal/mol at ' +
               "{0:.2f}".format(alpha) + '% HFX'))
        print(Oxo_ANN_trust_message)
        print(('Distance to oxo training data in the latent space is ' +
               "{0:.2f}".format(oxo_dist)))
        print(("ANN predicts a HAT energy of " +
               "{0:.2f}".format(float(hat[0])) + ' kcal/mol at ' +
               "{0:.2f}".format(alpha) + '% HFX'))
        print(HAT_ANN_trust_message)
        print(('Distance to HAT training data in the latent space is ' +
               "{0:.2f}".format(hat_dist)))
        print(
            "*******************************************************************"
        )
        print(
            "************** ANN complete, saved in record file *****************"
        )
        print(
            "*******************************************************************"
        )
        from keras import backend as K
        # This is done to get rid of the attribute error that is a bug in tensorflow.
        K.clear_session()

    if catalysis:
        current_time = time.time()
        total_ANN_time = current_time - start_time
        last_time = current_time
        print(('Total Catalysis ML functions took ' +
               "{0:.2f}".format(total_ANN_time) + ' seconds'))

    if not valid and not ANN_reason and not catalysis:
        ANN_reason = ' uncaught rejection (see sdout/stderr)'

    return valid, ANN_reason, ANN_attributes, catalysis

    if False:
        # test Euclidean norm to training data distance
        train_dist, best_row = find_eu_dist(nn_excitation)
        ANN_trust = max(0.01, 1.0 - train_dist)

        ANN_attributes.update({'ANN_closest_train': best_row})

        print((' with closest training row ' + best_row[:-2] + ' at  ' +
               str(best_row[-2:]) + '% HFX'))

        # use ANN to predict fucntional sensitivty
        HFX_slope = 0
        HFX_slope = get_slope(slope_excitation)
        print(('Predicted HFX exchange sensitivity is : ' +
               "{0:.2f}".format(float(HFX_slope)) + ' kcal/HFX'))
        ANN_attributes.update({'ANN_slope': HFX_slope})
Ejemplo n.º 5
0
def ANN_preproc(args, ligs, occs, dents, batslist, tcats, licores):
    """Validate an octahedral complex request and run the spin/bond ANN.

    The ligand list is expanded by occupancy. The request is then checked
    for geometry, oxidation state and ligand symmetry, reduced to the ANN
    descriptor vectors, and passed to the ANN models. Predictions and
    trust information are printed and collected in ``ANN_attributes``.

    Parameters
    ----------
    args : namespace
        Run arguments. Reads ``core``, ``geometry``, ``oxstate``, ``spin``,
        ``exchange``, ``debug``, ``decoration`` and ``decoration_index``.
    ligs, occs, dents, tcats : list
        Parallel lists of ligand names, occupancies, denticities and
        custom connecting atoms.
    batslist : list
        Backbone atom lists. Its length sets how many copies of the
        predicted bond length are stored under ``'ANN_bondl'``.
    licores : dict
        Ligand dictionary forwarded to ``lig_load``.

    Returns
    -------
    valid : bool
        Whether the ANN could be applied to this complex.
    ANN_reason : str or bool
        Reason for rejecting the ANN call, or False if none was recorded.
    ANN_attributes : dict
        ANN predictions and trust information.
    """
    ANN_reason = False  # holder for reason to reject ANN call
    ANN_attributes = dict()

    emsg = list()
    valid = True
    metal = args.core
    this_metal = metal.lower()
    if len(this_metal) > 2:
        this_metal = this_metal[0:2]

    # Expand each ligand by its occupancy so every coordination slot gets
    # its own entry. Decorations are tracked per slot (6 slots for oct).
    newligs = []
    newcats = []
    newdents = []
    newdecs = [False] * 6
    # Entries are replaced, never mutated in place, so the shared inner
    # list created by `[[]] * 6` is harmless.
    newdec_inds = [[]] * 6
    count = -1
    for i, lig in enumerate(ligs):
        for _ in range(0, int(occs[i])):
            count += 1
            newligs.append(lig)
            newdents.append(dents[i])
            newcats.append(tcats[i])
            if args.decoration:
                newdecs[count] = args.decoration[i]
                newdec_inds[count] = args.decoration_index[i]

    ligs = newligs
    dents = newdents
    tcats = newcats

    if not args.geometry == "oct":
        valid = False
        ANN_reason = 'geometry not oct'
    if not args.oxstate:
        emsg.append("\n oxidation state must be given")
        valid = False
        ANN_reason = 'oxstate not given'
    if valid:
        oxidation_state = args.oxstate
        valid, oxidation_state = check_metal(this_metal, oxidation_state)
        # generate key in descriptor space
        ox = int(oxidation_state)
        spin = args.spin
        if args.debug:
            print(('metal is ' + str(this_metal)))
            print(('metal validity', valid))
        if not valid:
            emsg.append("\n Oxidation state not available for this metal")
            ANN_reason = 'ox state not avail for metal'
    if valid:
        # spin_classify returns no validity flag, so nothing here can
        # reject the spin state. The former "spin state not available"
        # branch could never execute and was removed.
        high_spin, spin_ops = spin_classify(this_metal, spin, ox)
    if emsg:
        print((str(" ".join(["ANN messages:"] + [str(i) for i in emsg]))))
    if valid:
        (valid, axial_ligs, equitorial_ligs, ax_dent, eq_dent, ax_tcat,
         eq_tcat, axial_ind_list, equitorial_ind_list) = check_ligands(
            ligs, batslist, dents, tcats)
        if args.debug:
            print("\n")
            print(("ligand validity is  " + str(valid)))
            print('Occs')
            print(occs)
            print('Ligands')
            print(ligs)
            print('Dents')
            print(dents)
            print('Bats (backbone atoms)')
            print(batslist)
            print(('lig validity', valid))
            print(('ax ligs', axial_ligs))
            print(('eq ligs', equitorial_ligs))
            print(('spin is', spin))
        if not valid:
            ANN_reason = 'found incorrect ligand symmetry'

    if valid:
        # axial ligand
        ax_lig3D, r_emsg = lig_load(axial_ligs[0], licores)  # load ligand
        if r_emsg:
            emsg += r_emsg
        # use the decorated ligand instead if a decoration was requested
        if newdecs[axial_ind_list[0]]:
            ax_lig3D = decorate_ligand(args, axial_ligs[0],
                                       newdecs[axial_ind_list[0]],
                                       newdec_inds[axial_ind_list[0]])
        ax_lig3D.convert2mol3D()  # mol3D representation of ligand
        # equatorial ligand
        eq_lig3D, r_emsg = lig_load(equitorial_ligs[0], licores)  # load ligand
        if newdecs[equitorial_ind_list[0]]:
            eq_lig3D = decorate_ligand(args, equitorial_ligs[0],
                                       newdecs[equitorial_ind_list[0]],
                                       newdec_inds[equitorial_ind_list[0]])
        if r_emsg:
            emsg += r_emsg
        eq_lig3D.convert2mol3D()  # mol3D representation of ligand
        if ax_tcat:
            ax_lig3D.cat = ax_tcat
            if args.debug:
                print(('custom ax connect atom given (0-ind) ' + str(ax_tcat)))
        if eq_tcat:
            eq_lig3D.cat = eq_tcat
            if args.debug:
                print(('custom eq connect atom given (0-ind) ' + str(eq_tcat)))
    if args.debug:
        print(('finished checking ligands, valid is ' + str(valid)))
    if valid:
        valid, ax_type = get_con_at_type(ax_lig3D, ax_lig3D.cat)
    if valid:
        valid, eq_type = get_con_at_type(eq_lig3D, eq_lig3D.cat)
        if args.debug:
            print(('finished con atom types ' + str(ax_type) + ' and ' +
                   str(eq_type)))

    if valid:
        eq_ki = get_truncated_kier(eq_lig3D, eq_lig3D.cat)
        ax_ki = get_truncated_kier(ax_lig3D, ax_lig3D.cat)
        eq_EN = get_lig_EN(eq_lig3D, eq_lig3D.cat)
        ax_EN = get_lig_EN(ax_lig3D, ax_lig3D.cat)
        eq_bo = get_bond_order(eq_lig3D.OBMol, eq_lig3D.cat, eq_lig3D)
        ax_bo = get_bond_order(ax_lig3D.OBMol, ax_lig3D.cat, ax_lig3D)

        # bond order is overridden to 3 for these ligands
        if axial_ligs[0] in ['carbonyl', 'cn']:
            ax_bo = 3
        if equitorial_ligs[0] in ['carbonyl', 'cn']:
            eq_bo = 3
        eq_charge = eq_lig3D.OBMol.GetTotalCharge()
        ax_charge = ax_lig3D.OBMol.GetTotalCharge()

        # preprocess: 2 axial + 4 equatorial positions
        sum_delen = (2.0) * ax_EN + (4.0) * eq_EN
        if abs(eq_EN) > abs(ax_EN):
            max_delen = eq_EN
        else:
            max_delen = ax_EN
        alpha = 0.2  # default for B3LYP
        if args.exchange:
            try:
                exchange = float(args.exchange)
            except (TypeError, ValueError):
                print('cannot case exchange argument as a float, using 20%')
            else:
                if exchange > 1:
                    alpha = exchange / 100  # if given as %
                elif exchange <= 1:
                    alpha = exchange

        if args.debug:
            print(('ax_bo', ax_bo))
            print(('eq_bo', eq_bo))
            print(('ax_dent', ax_dent))
            print(('eq_dent', eq_dent))
            print(('ax_charge', ax_charge))
            print(('eq_charge', eq_charge))
            print(('sum_delen', sum_delen))
            print(('max_delen', max_delen))
            print(('ax_type', ax_type))
            print(('eq_type', eq_type))
            print(('ax_ki', ax_ki))
            print(('eq_ki', eq_ki))

    if valid:
        # The zero slots are one-hot encodings filled in by the
        # *_corrector helpers below.
        nn_excitation = [
            0, 0, 0, 0, 0,  # metals co/cr/fe/mn/ni #1-5
            ox, alpha, ax_charge, eq_charge,  # ox/alpha/axlig charge/eqlig charge #6-9
            ax_dent, eq_dent,  # ax_dent/eq_dent #10-11
            0, 0, 0, 0,  # axlig_connect: Cl,N,O,S #12-15
            0, 0, 0, 0,  # eqlig_connect: Cl,N,O,S #16-19
            sum_delen, max_delen,  # mdelen, maxdelen #20-21
            ax_bo, eq_bo,  # axlig_bo, eqlig_bo #22-23
            ax_ki, eq_ki,  # axlig_ki, eqlig_ki #24-25
        ]
        # same as nn_excitation but without alpha
        slope_excitation = [
            0, 0, 0, 0, 0,  # metals co/cr/fe/mn/ni #1-5
            ox, ax_charge, eq_charge,  # ox/axlig charge/eqlig charge #6-8
            ax_dent, eq_dent,  # ax_dent/eq_dent #9-10
            0, 0, 0, 0,  # axlig_connect: Cl,N,O,S #11-14
            0, 0, 0, 0,  # eqlig_connect: Cl,N,O,S #15-18
            sum_delen, max_delen,  # mdelen, maxdelen #19-20
            ax_bo, eq_bo,  # axlig_bo, eqlig_bo #21-22
            ax_ki, eq_ki,  # axlig_ki, eqlig_ki #23-24
        ]

    # discrete variable encodings
    if valid:
        valid, nn_excitation = metal_corrector(nn_excitation, this_metal)
        valid, slope_excitation = metal_corrector(slope_excitation, this_metal)
    if valid:
        valid, nn_excitation = ax_lig_corrector(nn_excitation, ax_type)
        valid, slope_excitation = ax_lig_corrector(slope_excitation, ax_type)
    if valid:
        valid, nn_excitation = eq_lig_corrector(nn_excitation, eq_type)
        valid, slope_excitation = eq_lig_corrector(slope_excitation, eq_type)

    if valid:
        print(
            "******************************************************************"
        )
        print(
            "************** ANN is engaged and advising on spin ***************"
        )
        print(
            "************** and metal-ligand bond distances    ****************"
        )
        print(
            "******************************************************************"
        )
        if high_spin:
            print(('You have selected a high-spin state, s = ' + str(spin)))
        else:
            print(('You have selected a low-spin state, s = ' + str(spin)))
        # test Euclidean norm to training data distance
        train_dist, best_row = find_eu_dist(nn_excitation)
        ANN_score = max(0.01, 1.0 - train_dist)

        ANN_attributes.update({'ANN_dist_to_train': train_dist})
        ANN_attributes.update({'ANN_closest_train': best_row})
        print(('distance to training data is ' + "{0:.2f}".format(train_dist) +
               ' ANN trust: ' + "{0:.2f}".format(ANN_score)))
        print((' with closest training row ' + best_row[:-2] + ' at  ' +
               str(best_row[-2:]) + '% HFX'))
        if float(train_dist) < 0.25:
            print('ANN results should be trustworthy for this complex ')
            ANN_trust = 'high'
        elif float(train_dist) < 0.75:
            print('ANN results are probably useful for this complex ')
            ANN_trust = 'medium'
        elif float(train_dist) < 1.0:
            print(
                'ANN results are fairly far from training data, be cautious ')
            ANN_trust = 'low'
        else:
            # Also covers train_dist == 1.0 exactly, which previously
            # matched no branch and was stored as 'not set'.
            print('ANN results are too far from training data, be cautious ')
            ANN_trust = 'very low'
        ANN_attributes.update({'ANN_trust': ANN_trust})

        # engage ANN
        delta, scaled_excitation = get_splitting(nn_excitation)
        # report to stdout
        if delta[0] < 0 and not high_spin:
            if abs(delta[0]) > 5:
                print(
                    'warning, ANN predicts a high spin ground state for this complex'
                )
            else:
                print(
                    'warning, ANN predicts a near degenerate ground state for this complex'
                )
        if delta[0] >= 0 and high_spin:
            if abs(delta[0]) > 5:
                print(
                    'warning, ANN predicts a low spin ground state for this complex'
                )
            else:
                print(
                    'warning, ANN predicts a near degenerate ground state for this complex'
                )
        print(("ANN predicts a spin splitting (HS - LS) of " +
               "{0:.2f}".format(float(delta[0])) + ' kcal/mol at ' +
               "{0:.0f}".format(100 * alpha) + '% HFX'))
        ANN_attributes.update({'pred_split_ HS_LS': delta[0]})
        # reparse to save attributes
        ANN_attributes.update({'This spin': spin})
        if delta[0] < 0 and (abs(delta[0]) > 5):
            ANN_attributes.update({'ANN_ground_state': spin_ops[1]})
        elif delta[0] > 0 and (abs(delta[0]) > 5):
            ANN_attributes.update({'ANN_ground_state': spin_ops[0]})
        else:
            # NOTE(review): key is misspelled ('gound'). It is kept as-is
            # because record-file consumers may read it.
            ANN_attributes.update({'ANN_gound_state': 'dgen ' + str(spin_ops)})

        r_ls = get_ls_dist(nn_excitation)
        r_hs = get_hs_dist(nn_excitation)
        r = r_hs if high_spin else r_ls

        print(('ANN bond length is predicted to be: ' +
               "{0:.2f}".format(float(r)) + ' angstrom'))
        ANN_attributes.update({'ANN_bondl': len(batslist) * [r[0]]})
        print(('ANN low spin bond length is predicted to be: ' +
               "{0:.2f}".format(float(r_ls)) + ' angstrom'))
        print(('ANN high spin bond length is predicted to be: ' +
               "{0:.2f}".format(float(r_hs)) + ' angstrom'))

        # use ANN to predict functional sensitivity
        HFX_slope = get_slope(slope_excitation)
        print(('Predicted HFX exchange sensitivity is : ' +
               "{0:.2f}".format(float(HFX_slope)) + ' kcal/HFX'))
        ANN_attributes.update({'ANN_slope': HFX_slope})
        print(
            "*******************************************************************"
        )
        print(
            "************** ANN complete, saved in record file *****************"
        )
        print(
            "*******************************************************************"
        )

    if not valid and not ANN_reason:
        ANN_reason = ' uncaught rejection (see sdout/stderr)'
    return valid, ANN_reason, ANN_attributes
Ejemplo n.º 6
0
def constrgen(rundir, args, globs):
    """Randomly generate complexes subject to user ligand constraints.

    Ligands listed in ``args.lig`` are always included. The remaining
    coordination sites are filled with ligand combinations sampled by
    ``getconstsample``. Candidates come from the ligand dictionary,
    optionally restricted by ``args.liggrp``/``args.ligctg``, and amino
    acids are always excluded. ``rungen`` is then called once per sampled
    combination.

    Parameters
    ----------
    rundir : str
        Run directory forwarded to ``rungen``.
    args : namespace
        Run arguments. ``lig``, ``ligocc``, ``keepHs`` and ``lignum`` are
        modified in place.
    globs : object
        Global settings forwarded to ``rungen``.

    Returns
    -------
    tuple
        ``(args, emsg)`` normally, or ``(False, emsg)`` if a constrained
        ligand fails to load.
    """
    emsg = False
    # load global variables
    licores = getlicores()
    print('Random generation started..\n\n')
    ligs0 = []
    ligocc0 = []
    coord = False if not args.coord else int(args.coord)
    if args.gui:
        args.gui.iWtxt.setText('\n----------------------------------------------------------------------------------\n' +
                               'Random generation started\nGenerating ligand combinations.\n\n'+args.gui.iWtxt.toPlainText())
        args.gui.app.processEvents()
    # if ligand constraint apply it now
    if args.lig:
        for i, l in enumerate(args.lig):
            ligs0.append(l)
            ligentry, emsg = lig_load(l)  # check ligand
            # update ligand
            if ligentry:
                args.lig[i] = ligentry.name
            if emsg:
                return False, emsg
            # Default occupancy when none was given for ligand i: the full
            # coordination number for a single ligand, otherwise 1. `<=`
            # (not `<`) is needed so a missing entry at index i is appended
            # before the lookup below.
            if args.ligocc:
                if len(args.ligocc) <= i and len(args.lig) == 1:
                    args.ligocc.append(coord)
                elif len(args.ligocc) <= i:
                    args.ligocc.append(1)
            else:
                args.ligocc = []
                if len(args.lig) == 1:
                    args.ligocc.append(coord)
                else:
                    args.ligocc.append(1)
            ligocc0.append(args.ligocc[i])
            if args.lignum:
                args.lignum = str(int(args.lignum) - 1)
            # check for smiles: denticity from the given connecting atoms
            if not ligentry.denticity:
                if args.smicat and len(args.smicat) > i and args.smicat[i]:
                    ligentry.denticity = len(args.smicat[i])
                else:
                    ligentry.denticity = 1
            # sites still free for randomly chosen ligands
            if coord:
                coord -= int(args.ligocc[i])*ligentry.denticity
            licores.pop(l, None)  # remove from dictionary
    # check for ligand groups (dict comprehensions keep key order, which
    # matters because sampled combinations index into the key list)
    if args.liggrp and 'all' != args.liggrp.lower():
        grp = args.liggrp.lower()
        ctg = args.ligctg.lower() if args.ligctg else None
        licores = {key: val for key, val in licores.items()
                   if grp in val[3]
                   and (ctg is None or ctg == 'all' or ctg in val[3])}
    # remove aminoacids
    licores = {key: val for key, val in licores.items()
               if 'aminoacid' not in val[3]}
    # get a sample of these combinations
    samps = getconstsample(int(args.rgen[0]), args, licores, coord)
    if len(samps) == 0:
        if coord == 0:
            args.lig = list(ligs0)
            args.ligocc = [int(a) for a in ligocc0]
            # run structure generation
            emsg = rungen(rundir, args, False, globs)
        else:
            if args.gui:
                from molSimplify.Classes.mWidgets import mQDialogErr
                qqb = mQDialogErr(
                    'Error', 'No suitable ligand sets were found for random generation. Exiting...')
                qqb.setParent(args.gui.wmain)
            else:
                emsg = 'No suitable ligand sets were found for random generation. Exiting...'
                print(
                    'No suitable ligand sets were found for random generation. Exiting...\n\n')
            return args, emsg
    # sample entries are indices into this key list
    lig_keys = list(licores.keys())
    # loop over samples
    for combo in samps:
        args.lig = list(ligs0)
        args.ligocc = [int(a) for a in ligocc0]
        lcount = Counter(combo)
        for cj in set(combo):
            args.lig.append(lig_keys[cj])
            args.ligocc.append(lcount[cj])
        # check for keep Hydrogens on the randomly added ligands
        opt = True if args.rkHs else False
        for iiH in range(len(ligs0), len(args.lig)):
            if args.keepHs and len(args.keepHs) > iiH:
                args.keepHs[iiH] = opt
            elif args.keepHs:
                args.keepHs.append(opt)
            else:
                args.keepHs = [opt]
        emsg = rungen(rundir, args, False, globs)  # run structure generation
    return args, emsg
Ejemplo n.º 7
0
def rungen(rundir, args, chspfname, globs, write_files=True):
    try:
        from Classes.mWidgets import qBoxFolder
        from Classes.mWidgets import mQDialogInf
        from Classes.mWidgets import mQDialogErr
    except ImportError:
        args.gui = False
    emsg = False
    globs.nosmiles = 0  # reset smiles ligands for each run
    # check for specified ligands/functionalization
    ligocc = []
    # check for files specified for multiple ligands
    mligs, catoms = [False], [False]
    if args.lig is not None:
        if '.smi' in args.lig[0]:
            ligfilename = args.lig[0].split('.')[0]
        if args.lig:
            mligs, catoms, multidx = checkmultilig(args.lig)
        if args.debug:
            print(('after checking for mulitple ligs, we found  ' +
                   str(multidx) + ' ligands'))
    # save initial
    smicat0 = [ss for ss in args.smicat] if args.smicat else False
    # loop over ligands
    for mcount, mlig in enumerate(mligs):
        args.smicat = [ss for ss in smicat0] if smicat0 else False
        args.checkdir, skip = False, False  # initialize flags
        if len(mligs) > 0 and mligs[0]:
            args.lig = mlig  # get combination
            if multidx != -1:
                if catoms[multidx][mcount]:
                    ssatoms = catoms[multidx][mcount].split(',')
                    lloc = [int(scat)-1 for scat in ssatoms]
                    # append connection atoms if specified in smiles
                    if args.smicat and len(args.smicat) > 0:
                        for i in range(len(args.smicat), multidx):
                            args.smicat.append([])
                    else:
                        args.smicat = [lloc]
                    args.smicat[multidx] = lloc
        if (args.lig):
            ligands = args.lig
            if (args.ligocc):
                ligocc = args.ligocc
            else:
                ligocc = ['1']
            for i in range(len(ligocc), len(ligands)):
                ligocc.append('1')
            lig = ''
            for i, l in enumerate(ligands):
                ligentry, emsg = lig_load(l)
                # update ligand
                if ligentry:
                    ligands[i] = ligentry.name
                    args.lig[i] = ligentry.name
                if emsg:
                    skip = True
                    break
                if ligentry.ident == 'smi':
                    ligentry.ident += str(globs.nosmiles)
                    globs.nosmiles += 1
                    if args.sminame:
                        if len(args.sminame) > int(ligentry.ident[-1]):
                            ligentry.ident = args.sminame[globs.nosmiles-1][0:3]
                lig += ''.join("%s%s" % (ligentry.ident, ligocc[i]))
        else:
            ligands = []
            lig = ''
            ligocc = ''
    # fetch smart name
        fname = name_complex(rundir, args.core, args.geometry, ligands, ligocc,
                             mcount, args, nconf=False, sanity=False, bind=args.bind, bsmi=args.nambsmi)
        if args.tsgen:
            substrate = args.substrate
            subcatoms = ['multiple']
            if args.subcatoms:
                subcatoms = args.subcatoms
            mlig = args.mlig
            mligcatoms = args.mligcatoms
            fname = name_ts_complex(rundir, args.core, args.geometry, ligands, ligocc, substrate, subcatoms,
                                    mlig, mligcatoms, mcount, args, nconf=False, sanity=False, bind=args.bind, bsmi=args.nambsmi)
        if globs.debug:
            print(('fname is ' + str(fname)))
        rootdir = fname
        # check for charges/spin
        rootcheck = False
        if (chspfname):
            rootcheck = rootdir
            rootdir = rootdir + '/'+chspfname
        if (args.suff):
            rootdir += args.suff
        # check for mannual overwrite of
        # directory name
        if args.jobdir:
            rootdir = rundir + args.jobdir
            # check for top directory
        if rootcheck and os.path.isdir(rootcheck) and not args.checkdirt and not skip:
            args.checkdirt = True
            if not args.rprompt:
                flagdir = get_input('\nDirectory '+rootcheck +
                                ' already exists. Keep both (k), replace (r) or skip (s) k/r/s: ')
                if 'k' in flagdir.lower():
                    flagdir = 'keep'
                elif 's' in flagdir.lower():
                    flagdir = 'skip'
                else:
                    flagdir = 'replace'
            else:
                #qqb = qBoxFolder(args.gui.wmain,'Folder exists','Directory '+rootcheck+' already exists. What do you want to do?')
                #flagdir = qqb.getaction()
                flagdir = 'replace'
                # replace existing directory
            if (flagdir == 'replace'):
                shutil.rmtree(rootcheck)
                os.mkdir(rootcheck)
            # skip existing directory
            elif flagdir == 'skip':
                skip = True
            # keep both (default)
            else:
                ifold = 1
                while glob.glob(rootdir+'_'+str(ifold)):
                    ifold += 1
                    rootcheck += '_'+str(ifold)
                    os.mkdir(rootcheck)
        elif rootcheck and (not os.path.isdir(rootcheck) or not args.checkdirt) and not skip:
            if globs.debug:
                print(('rootcheck is  ' + str(rootcheck)))
            args.checkdirt = True
            try:
                os.mkdir(rootcheck)
            except:
                print(('Directory '+rootcheck+' can not be created. Exiting..\n'))
                return
            # check for actual directory
        if os.path.isdir(rootdir) and not args.checkdirb and not skip and not args.jobdir:
            args.checkdirb = True
            if not args.rprompt:
                flagdir = get_input(
                    '\nDirectory '+rootdir + ' already exists. Keep both (k), replace (r) or skip (s) k/r/s: ')
                if 'k' in flagdir.lower():
                    flagdir = 'keep'
                elif 's' in flagdir.lower():
                    flagdir = 'skip'
                else:
                    flagdir = 'replace'
            else:
                #qqb = qBoxFolder(args.gui.wmain,'Folder exists','Directory '+rootdir+' already exists. What do you want to do?')
                #flagdir = qqb.getaction()
                flagdir = 'replace'
            # replace existing directory
            if (flagdir == 'replace'):
                shutil.rmtree(rootdir)
                os.mkdir(rootdir)
            # skip existing directory
            elif flagdir == 'skip':
                skip = True
            # keep both (default)
            else:
                ifold = 1
                while glob.glob(rootdir+'_'+str(ifold)):
                    ifold += 1
                rootdir += '_'+str(ifold)
                os.mkdir(rootdir)
        elif not os.path.isdir(rootdir) or not args.checkdirb and not skip:
            if not os.path.isdir(rootdir):
                if write_files:
                    args.checkdirb = True
                    os.mkdir(rootdir)
            ####################################
            ############ GENERATION ############
            ####################################
        if not skip:
            # check for generate all
            if args.genall:
                tstrfiles = []
                # generate xyz with FF and trained ML
                args.ff = 'mmff94'
                args.ffoption = 'ba'
                args.MLbonds = False
                strfiles, emsg, this_diag = structgen(
                    args, rootdir, ligands, ligocc, globs, mcount, write_files=write_files)
                for strf in strfiles:
                    tstrfiles.append(strf+'FFML')
                    os.rename(strf+'.xyz', strf+'FFML.xyz')
                # generate xyz with FF and covalent
                args.MLbonds = ['c' for i in range(0, len(args.lig))]
                strfiles, emsg, this_diag = structgen(
                    args, rootdir, ligands, ligocc, globs, mcount, write_files=write_files)
                for strf in strfiles:
                    tstrfiles.append(strf+'FFc')
                    os.rename(strf+'.xyz', strf+'FFc.xyz')
                args.ff = False
                args.ffoption = False
                args.MLbonds = False
                # generate xyz without FF and trained ML
                strfiles, emsg, this_diag = structgen(
                    args, rootdir, ligands, ligocc, globs, mcount, write_files=write_files)
                for strf in strfiles:
                    tstrfiles.append(strf+'ML')
                    os.rename(strf+'.xyz', strf+'ML.xyz')
                args.MLbonds = ['c' for i in range(0, len(args.lig))]
                # generate xyz without FF and covalent ML
                strfiles, emsg, this_diag = structgen(
                    args, rootdir, ligands, ligocc, globs, mcount, write_files=write_files)
                for strf in strfiles:
                    tstrfiles.append(strf+'c')
                    os.rename(strf+'.xyz', strf+'c.xyz')
                strfiles = tstrfiles
            else:
                # generate xyz files
                strfiles, emsg, this_diag = structgen(
                    args, rootdir, ligands, ligocc, globs, mcount, write_files=write_files)
            # generate QC input files
            if args.qccode and (not emsg) and write_files:
                if args.charge and (isinstance(args.charge, list)):
                    args.charge = args.charge[0]
                if args.spin and (isinstance(args.spin, list)):
                    args.spin = args.spin[0]
                if args.qccode.lower() in 'terachem tc Terachem TeraChem TERACHEM TC':
                    jobdirs = multitcgen(args, strfiles)
                    print('TeraChem input files generated!')
                elif 'gam' in args.qccode.lower():
                    jobdirs = multigamgen(args, strfiles)
                    print('GAMESS input files generated!')
                elif 'qch' in args.qccode.lower():
                    jobdirs = multiqgen(args, strfiles)
                    print('QChem input files generated!')
                elif 'orc' in args.qccode.lower():
                    jobdirs = multiogen(args, strfiles)
                    print('ORCA input files generated!')
                elif 'molc' in args.qccode.lower():
                    jobdirs = multimolcgen(args, strfiles)
                    print('MOLCAS input files generated!')
                else:
                    print(
                        'Only TeraChem, GAMESS, QChem, ORCA, MOLCAS are supported right now.\n')
            # check molpac
            if args.mopac and (not emsg) and write_files:
                print('Generating MOPAC input')
                if globs.debug:
                    print(strfiles)
                jobdirs = mlpgen(args, strfiles, rootdir)
            # generate jobscripts
            if args.jsched and (not emsg) and (not args.reportonly) and (write_files):
                if args.jsched in 'SBATCH SLURM slurm sbatch':
                    slurmjobgen(args, jobdirs)
                    print('SLURM jobscripts generated!')
                elif args.jsched in 'SGE Sungrid sge':
                    sgejobgen(args, jobdirs)
                    print('SGE jobscripts generated!')

            elif multidx != -1:  # if ligand input was a list of smiles strings, write good smiles strings to separate list
                try:
                    f = open(ligfilename+'-good.smi', 'a')
                    f.write(args.lig[0])
                    f.close()
                except:
                    0
        elif not emsg:
            if args.gui:
                qq = mQDialogInf('Folder skipped', 'Folder ' +
                                 rootdir+' was skipped.')
                qq.setParent(args.gui.wmain)
            else:
                print(('Folder '+rootdir+' was skipped..\n'))
    if write_files:
        return emsg # Default behavior
    else:
        return strfiles, emsg, this_diag # Assume that user wants these if they're not writing files
Ejemplo n.º 8
0
def decorate_ligand(args, ligand_to_decorate, decoration, decoration_index):
    """Attach decoration fragments to a ligand by replacing selected atoms.

    Each entry of ``decoration`` is loaded with ``lig_load`` and positioned on
    the atom named by the matching entry of ``decoration_index``. It is then
    rotated to the orientation that scores best on a simple distance
    criterion, and bonded to the ligand in place of that atom, which is
    deleted. The merged molecule is finally relaxed with ``ffopt`` using
    MMFF94.

    Parameters
    ----------
    args : namespace
        Input-argument container. Only ``args.debug`` is read here.
    ligand_to_decorate : mol3D or other
        A ``mol3D`` is used directly and is modified in place: its OBMol is
        regenerated, its ``charge`` is reset from the OBMol total charge, and
        it is converted back to mol3D. Anything else is passed to
        ``lig_load`` together with the dictionary from ``getlicores()``
        (presumably a ligand name or SMILES string; confirm with lig_load).
    decoration : list of str
        Decorations to attach. Each one is passed to ``lig_load``.
    decoration_index : list of int
        Atom indices in the ligand to replace, one per decoration.
        ``decoration[k]`` replaces atom ``decoration_index[k]``.

    Returns
    -------
    mol3D
        The decorated ligand after the MMFF94 relaxation. Only the molecule
        is returned. The ``emsg`` values from ``lig_load`` and the second
        value returned by ``ffopt`` are discarded.

    Notes
    -----
    - A progress line is printed for every decoration, even when
      ``args.debug`` is off.
    - With ``args.debug`` set, intermediate .xyz files are written to the
      current working directory.
    - Presumably raises IndexError if a replaced atom has no non-hydrogen
      bonded neighbour. See the ``getBondedAtomsnotH(...)[0]`` lookup below.
    """
    # structgen depends on decoration_manager, and decoration_manager depends on structgen.ffopt
    # Thus, this import needs to be placed here to avoid a circular dependence
    from molSimplify.Scripts.structgen import ffopt
    lig = ligand_to_decorate
    # Process the decorations in descending order of the atom index they
    # replace. This assumes deleteatom only renumbers atoms after the deleted
    # one, which the bond re-indexing below also assumes. Under that
    # assumption, each deletion leaves the indices still waiting to be
    # replaced (all lower) unchanged.
    sort_order = [
        i[0] for i in sorted(enumerate(decoration_index), key=lambda x: x[1])
    ]
    sort_order = sort_order[::-1]  # reverse -> descending by atom index

    # New lists are built here, so the caller's lists are not reordered.
    decoration_index = [decoration_index[i] for i in sort_order]
    decoration = [decoration[i] for i in sort_order]
    if args.debug:
        print(('decoration_index  is  ' + str(decoration_index)))
    licores = getlicores()
    if not isinstance(lig, mol3D):
        lig, emsg = lig_load(lig, licores)
    else:
        # Caller-supplied mol3D: regenerate its OBMol and take the charge from
        # it. These calls act on the caller's object.
        lig.convert2OBMol()
        lig.charge = lig.OBMol.GetTotalCharge()
    lig.convert2mol3D()  # convert to mol3D

    ## create new ligand
    # The result is built in a separate mol3D populated from lig via copymol3D.
    merged_ligand = mol3D()
    merged_ligand.copymol3D(lig)
    for i, dec in enumerate(decoration):
        # Progress message, printed even when args.debug is off.
        print(('** decoration number ' + str(i) + ' attaching ' + dec +
               ' at site ' + str(decoration_index[i]) + '**\n'))
        dec, emsg = lig_load(dec, licores)
        # dec.OBMol.AddHydrogens()
        dec.convert2mol3D()  # convert to mol3D
        if args.debug:
            print(i)
            print(decoration_index)

            print((merged_ligand.getAtom(decoration_index[i]).symbol()))
            print((merged_ligand.getAtom(decoration_index[i]).coords()))
            merged_ligand.writexyz('basic.xyz')
        #dec.writexyz('dec' + str(i) + '.xyz')
        # If the decoration declares no connecting atoms (dec.cat is empty),
        # atom 0 becomes the anchor. In that case drop one hydrogen attached
        # to it to free a valence (presumably getHsbyIndex(0) lists the H
        # atoms bonded to atom 0), and lower the fragment charge by one.
        Hs = dec.getHsbyIndex(0)
        if len(Hs) > 0 and (not len(dec.cat)):
            dec.deleteatom(Hs[0])
            dec.charge = dec.charge - 1

        #dec.writexyz('dec_noH' + str(i) + '.xyz')
        # decind: index of the decoration atom (the anchor) that will bond to
        # the ligand.
        if len(dec.cat) > 0:
            decind = dec.cat[0]
        else:
            decind = 0
        # Align the anchor with the atom being replaced. This presumably
        # translates dec so the two coincide; confirm alignmol semantics.
        dec.alignmol(dec.getAtom(decind),
                     merged_ligand.getAtom(decoration_index[i]))
        r1 = dec.getAtom(decind).coords()
        # Global center of mass. This value is never read: r2 is only used
        # inside the branch below, which overwrites it first.
        r2 = dec.centermass()  # center of mass
        rrot = r1
        decb = mol3D()
        decb.copymol3D(dec)
        ####################################
        # center of mass of local environment (to avoid bad placement of bulky ligands)
        auxmol = mol3D()
        for at in dec.getBondedAtoms(decind):
            auxmol.addAtom(dec.getAtom(at))
        if auxmol.natoms > 0:
            r2 = auxmol.centermass()  # overwrite global with local centermass
            ####################################
            # rotate around axis and get both images
            # (one rotated by theta, one by theta - 180)
            theta, u = rotation_params(merged_ligand.centermass(), r1, r2)
            #print('u = ' + str(u) + ' theta  = ' + str(theta))
            dec = rotate_around_axis(dec, rrot, u, theta)
            if args.debug:
                dec.writexyz('dec_ARA' + str(i) + '.xyz')
            decb = rotate_around_axis(decb, rrot, u, theta - 180)
            if args.debug:
                decb.writexyz('dec_ARB' + str(i) + '.xyz')
            d1 = distance(dec.centermass(), merged_ligand.centermass())
            d2 = distance(decb.centermass(), merged_ligand.centermass())
            # Keep the image whose center of mass is farther from the ligand's
            # center of mass. Ties go to decb.
            dec = dec if (d2 < d1) else decb  # pick best one
        #####################################
        # check for linear molecule
        # If the anchor has at least two neighbours and the first two are
        # colinear with it, apply an extra rotation about the anchor (r0).
        auxm = mol3D()
        for at in dec.getBondedAtoms(decind):
            auxm.addAtom(dec.getAtom(at))
        if auxm.natoms > 1:
            r0 = dec.getAtom(decind).coords()
            r1 = auxm.getAtom(0).coords()
            r2 = auxm.getAtom(1).coords()
            if checkcolinear(r1, r0, r2):
                theta, urot = rotation_params(
                    r1,
                    merged_ligand.getAtom(decoration_index[i]).coords(), r2)
                theta = vecangle(
                    vecdiff(
                        r0,
                        merged_ligand.getAtom(decoration_index[i]).coords()),
                    urot)
                dec = rotate_around_axis(dec, r0, urot, theta)

        ## get the default distance between atoms in question
        # Move the anchor toward the sum of the `rad` values of itself and the
        # first non-H neighbour of the replaced atom. Presumably raises
        # IndexError if the replaced atom has no non-H neighbour.
        # NOTE(review): position_to_place is not normalised and `missing` is
        # halved. The shift is therefore missing * old_dist along that vector,
        # so the resulting distance will generally not equal target_distance.
        # Confirm whether this partial correction is intended; the final
        # MMFF94 relaxation may compensate.
        connection_neighbours = merged_ligand.getAtom(
            merged_ligand.getBondedAtomsnotH(decoration_index[i])[0])
        new_atom = dec.getAtom(decind)
        target_distance = connection_neighbours.rad + new_atom.rad
        position_to_place = vecdiff(new_atom.coords(),
                                    connection_neighbours.coords())
        old_dist = norm(position_to_place)
        missing = (target_distance - old_dist) / 2
        dec.translate([missing * position_to_place[j] for j in [0, 1, 2]])

        # Rotation axis direction: from the replaced atom to the anchor (r1).
        r1 = dec.getAtom(decind).coords()
        u = vecdiff(r1, merged_ligand.getAtom(decoration_index[i]).coords())
        dtheta = 2
        optmax = -9999
        totiters = 0
        decb = mol3D()
        decb.copymol3D(dec)
        # check for minimum distance between atoms and center of mass distance
        # Apply 180 successive rotations of dtheta about the axis. The units
        # are presumably degrees, which makes one full turn. Keep a copy of
        # the best-scoring orientation (mindist + distance to the ligand).
        # NOTE(review): merged_ligand still contains the atom being replaced,
        # which is only deleted further below. The axis presumably passes
        # through both that atom and the anchor, so their separation is the
        # same at every step. If that pair is the closest one, d0 is constant
        # and only d0cm drives the choice. Consider excluding the atom from
        # the mindist check.
        while totiters < 180:
            #print('totiters '+ str(totiters))
            dec = rotate_around_axis(dec, r1, u, dtheta)
            d0 = dec.mindist(
                merged_ligand)  # try to maximize minimum atoms distance
            d0cm = dec.distance(
                merged_ligand)  # try to maximize center of mass distance
            iteropt = d0cm + d0  # optimization function
            if (iteropt > optmax):  # if better conformation, keep
                decb = mol3D()
                decb.copymol3D(dec)
                optmax = iteropt
                #temp = mol3D()
                #temp.copymol3D(merged_ligand)
                #temp.combine(decb)
                #temp.writexyz('opt_iter_'+str(totiters)+'.xyz')
                #print('new max! ' + str(iteropt) )
            totiters += 1
        dec = decb
        if args.debug:
            dec.writexyz('dec_aligned' + str(i) + '.xyz')
            print(('natoms before delete ' + str(merged_ligand.natoms)))
            print(('obmol before delete at  ' + str(decoration_index[i]) +
                   ' is ' + str(merged_ligand.OBMol.NumAtoms())))
        ## store connectivity of the atom being replaced
        # Its bonds, with their orders, are re-created to the decoration's
        # anchor after the merge.
        BO_mat = merged_ligand.populateBOMatrix()
        row_deleted = BO_mat[decoration_index[i]]
        bonds_to_add = []

        # find where to put the new bonds (the original author flagged an
        # unspecified issue in this section).
        # merged_ligand.natoms is read here, before the deletion, so
        # natoms - 1 is the atom count after it. The anchor's index in the
        # merged molecule is taken as (natoms - 1) + anchor index, which
        # assumes combine() appends dec's atoms after the ligand's atoms. The
        # else-branch is the same formula with decind == 0.
        for j, els in enumerate(row_deleted):
            if els > 0:
                # if there is a bond with an atom number
                # before the deleted atom, all is fine
                # else, we subtract one as the row will be be removed
                if j < decoration_index[i]:
                    bond_partner = j
                else:
                    bond_partner = j - 1
                if len(dec.cat) > 0:
                    bonds_to_add.append(
                        (bond_partner, (merged_ligand.natoms - 1) + dec.cat[0],
                         els))
                else:
                    bonds_to_add.append(
                        (bond_partner, merged_ligand.natoms - 1, els))

        ## perform delete
        merged_ligand.deleteatom(decoration_index[i])

        merged_ligand.convert2OBMol()
        if args.debug:
            merged_ligand.writexyz('merged del ' + str(i) + '.xyz')
        ## merge and bond
        merged_ligand.combine(dec, bond_to_add=bonds_to_add)
        merged_ligand.convert2OBMol()

        if args.debug:
            merged_ligand.writexyz('merged' + str(i) + '.xyz')
            merged_ligand.printxyz()
            print('************')

    # Relax the merged ligand with MMFF94 (100 is presumably the step count).
    # The second return value is stored in emsg but is neither used nor
    # returned.
    merged_ligand.convert2OBMol()
    merged_ligand, emsg = ffopt('MMFF94', merged_ligand, [], 0, [], False, [],
                                100)
    # BO_mat is only used by the debug print below.
    BO_mat = merged_ligand.populateBOMatrix()
    if args.debug:
        merged_ligand.writexyz('merged_relaxed.xyz')
        print(BO_mat)
    return (merged_ligand)