Example #1
def main(args):

    if args.ligand_type == 'ligand':
        init_args = ' '.join(['-extra_res_fa', args.ligand_name])
        lig_name = args.ligand_name.split('/')[-1].replace('.params', '')
    else:
        init_args = ''

    pr.init(init_args)
    sf = pr.get_fa_scorefxn()
    sf.add_weights_from_file('ref2015')

    pose = pr.pose_from_pdb(args.input_pdb)
    sf(pose)

    print_notice("Scaffold protein Loaded Successfully!")
    print_notice("Scaffold protein has" + str(pose.total_residue()) +
                 "residues.")

    if args.symmetric_file:
        sfsm = SetupForSymmetryMover(args.symmetric_file)
        sfsm.apply(pose)

    #Importing list of residues if the ligand is a protein
    if args.ligand_type == 'protein':
        ligres = ResidueIndexSelector(args.residue_set)
    #Targeting the ligand if the ligand isn't protein
    elif args.ligand_type == 'ligand':
        ligres = ResidueNameSelector()
        ligres.set_residue_name3(lig_name)

    print_notice("Ligand found at resnum: " + \
            str(get_residues_from_subset(ligres.apply(pose))) )

    #Selecting everything that is not part of the ligand
    not_ligand = NotResidueSelector()
    not_ligand.set_residue_selector(ligres)

    #Selecting residues at the interface between the ligand and the rest of the protein
    ligand_distant_contacts = InterGroupInterfaceByVectorSelector()
    ligand_distant_contacts.group1_selector(ligres)
    ligand_distant_contacts.group2_selector(not_ligand)
    ligand_distant_contacts.cb_dist_cut(2.5 * float(args.radius))
    ligand_distant_contacts.nearby_atom_cut(float(args.radius))

    #Close-contact selection using CloseContactResidueSelector
    close_contacts = pr.rosetta.core.select.residue_selector.CloseContactResidueSelector(
    )
    close_contacts.central_residue_group_selector(ligres)
    close_contacts.threshold(float(args.radius))

    all_contacts = OrResidueSelector()
    all_contacts.add_residue_selector(close_contacts)
    all_contacts.add_residue_selector(ligand_distant_contacts)

    non_lig_residues = AndResidueSelector()
    non_lig_residues.add_residue_selector(all_contacts)
    non_lig_residues.add_residue_selector(not_ligand)

    #Collecting the residues from the subset
    neighbor_residues = get_residues_from_subset(non_lig_residues.apply(pose))
    pdb_residues = []
    for residue in neighbor_residues:
        print(pose.pdb_info().pose2pdb(residue))
        resid = pose.pdb_info().pose2pdb(residue).split()
        pdb_residues.append(''.join(resid))

    print_notice("Ligand found, neighbor residues are: " +
                 ', '.join([x for x in pdb_residues]))
    print_notice("Ligand found, total neighbor residues is " +
                 str(len(pdb_residues)))

    #Removing residues listed in the REMARK records of the input pdb
    remove_set = []
    with open(args.input_pdb, 'r') as f:
        for line in f.readlines():
            if 'REMARK' in line:
                #Each REMARK line is expected to list two constrained residue numbers
                items = [x for x in line.split(' ') if x != '']
                residue_set = [int(items[6]), int(items[11])]
                for resi in residue_set:
                    if resi not in remove_set:
                        remove_set.append(resi)

    residue_final_set = []
    for resi in pdb_residues:
        idx = int(resi[0:-1])
        if idx not in remove_set:
            residue_final_set.append(resi)
    #Final list for the designable residues
    residue_final_set.append('0')
    print_notice("Neighbor residues cleaned \n \n Residues are: " +\
            ', '.join([x for x in residue_final_set]))

    if args.out_file:
        out_name = args.out_file
    else:
        out_name = args.input_pdb.replace('.pdb', '') + '_lig.pos'

    with open(out_name, "w") as f:
        for x in residue_final_set:
            f.write(x + '\n')
    print("emd182::Wrote ligand position residues of pdb file " +
          args.input_pdb + " to filename: " + out_name)
Example #2
def main(args):

    #Determine whether the ligand should be removed. This has to happen first,
    #because whether the ligand remains dictates how the CST file is handled.
    out_ligand_file = 'yeslig'
    if args.remove_ligand:
        args.rigid_ligand = False
        out_ligand_file = 'nolig'

    params = [args.unnatural]
    uaa = params[0].split('/')[-1].replace('.params', '')
    if args.ligand_type == 'ligand':
        params.append(args.ligand_name)
        lig_name = args.ligand_name.split('/')[-1].replace('.params', '')

    init_args = ' '.join(['-run:preserve_header', '-extra_res_fa'] + params)

    #Adding enzdes constraint file - and editing it - if necessary
    if args.enzdes_constraint_file:
        if out_ligand_file == 'nolig':
            if args.ligand_type == 'protein':
                lig_search = args.residue_set
            elif args.ligand_type == 'ligand':
                lig_search = args.ligand_name
            args.input_pdb, args.enzdes_constraint_file = enzdes_constraint_eliminator(\
                    args.input_pdb, args.enzdes_constraint_file, ligand=lig_search )

        #init_args = init_args + " -enzdes::cstfile " + str(args.enzdes_constraint_file)


    #Starting up rosetta with the appropriate params files --
    #both the unnatural aa and the ligand file need to be added
    #(if the ligand isn't a protein).
    print_out("emd182::Starting rosetta with the following parameters: " +
              init_args)
    pr.init(init_args)
    file_options = pr.rosetta.core.io.StructFileRepOptions()
    file_options.set_preserve_header(True)

    pose = pr.pose_from_pdb(args.input_pdb)

    #Residue selection
    if args.residue_number != '0':
        delta_resi = ResidueIndexSelector(args.residue_number)

    #Determining if the interacting ligand is a protein or small molecule
    #and selecting the appropriate residues
    if args.ligand_type == 'protein':
        ligand = ResidueIndexSelector(args.residue_set)
    elif args.ligand_type == 'ligand':
        ligand = ResidueNameSelector()
        ligand.set_residue_name3(lig_name)

    #Adding the unnatural at the mutation site
    print_out("Loading Unnatural " + uaa + \
        " onto residue number " + args.residue_number )

    if args.residue_number in selector_to_vector(ligand, pose):
        print_out("Selected residue number IS part of the ligand. Don't do that. " +
                  "Chosen residue number: " + str(args.residue_number))
    #Setting up Mutation function on the correct residue, so long as the
    #residue isn't #0
    if args.residue_number != '0':
        mutater = pr.rosetta.protocols.simple_moves.MutateResidue()
        mutating_residue = ResidueIndexSelector(args.residue_number)
        mutater.set_selector(mutating_residue)
        mutater.set_res_name(uaa)
        print_out("emd182::loading residue to mutate into: " +
                  str(args.residue_number))

    lig_vector = selector_to_vector(ligand, pose)
    if args.residue_number != '0':
        if args.ligand_type == 'protein' and str(args.residue_number) in lig_vector:
            print_out("Selected residue for mutation is part of the input selection: " + \
                    str(args.residue_number) + " is selected, but is contained in " \
                    + str(lig_vector))
            print_out("Exiting the python script")
            sys.exit()

    print_out("emd182::Start mutations and relaxation script " \
        + str(args.nstruct) + " times.")

    if args.residue_number == '0':
        args.nstruct = 1

    for struct in range(0, args.nstruct):
        print_out(struct)
        mutant_pose = pr.Pose(pose)
        #Residue selection
        if args.residue_number != '0':
            delta_resi = ResidueIndexSelector(args.residue_number)

        #apply mutation
        if args.residue_number != '0':
            mutater.apply(mutant_pose)
            if args.dihedral_constraint_atoms:
                dihedral_atoms = args.dihedral_constraint_atoms.split(';')
                dihedral_values = args.dihedral_constraint_degrees.split(';')
                dihedral_cst_stdev = args.dihedral_constraint_stdev.split(';')
                for dihedrals in range(len(dihedral_atoms)):
                    apply_dihedral_constraint(dihedral_atoms[dihedrals].split(','), \
                                delta_resi, mutant_pose, dihedral_values[dihedrals],\
                                dihedral_cst_stdev[dihedrals])


        #Scoring the pose
        if args.symmdef_file:
            #Record a ResidueIndexSelector over the pre-symmetry residues so the
            #extra symmetry copies can be removed after design
            pre_symm_ris = ResidueIndexSelector()
            all_aa = []
            for i in range(1, mutant_pose.total_residue() + 1):
                all_aa.append(''.join(
                    mutant_pose.pdb_info().pose2pdb(i).split()))
            pre_symm_ris.set_index(','.join(all_aa))
            #Symmetrizing
            sfsm = SetupForSymmetryMover(args.symmdef_file)
            sfsm.apply(mutant_pose)
            sf = pr.rosetta.core.scoring.symmetry.SymmetricScoreFunction()
        else:
            sf = pr.rosetta.core.scoring.ScoreFunction()
        sf.add_weights_from_file('ref2015_cst')

        sf(mutant_pose)


        #Making the appropriate residue selections based on whether the ligand
        #will be removed, and whether it is rigid
        residues_around_ligand = ligand_neighbor_selection(ligand, args.radius, \
                mutant_pose, bool(args.rigid_ligand) )

        if args.residue_number != '0':
            residues_around_mutant = ligand_neighbor_selection(mutating_residue, \
                    args.radius, mutant_pose, True)

        #If design is turned on, design residues around the mutant.
        #More variability can be added later.
        design_around_mutant = None
        if args.design and args.residue_number != '0':
            design_around_mutant = ligand_neighbor_selection(mutating_residue, \
                    args.design, mutant_pose, False)

        #Combining the repacking neighborhoods around the ligand and mutant
        if args.residue_number != '0':
            repacking_neighbors = residue_selection(residues_around_ligand, \
                    residues_around_mutant)
        else:
            repacking_neighbors = ResidueIndexSelector(residues_around_ligand)

        #Converting the residue selectors to vectors and removing the ligand
        #residue ids from the repacking and designing lists
        repacking_resids = selector_to_vector(repacking_neighbors, mutant_pose)
        if args.rigid_ligand:
            lig_resids = selector_to_vector(ligand, mutant_pose)
            for res in lig_resids:
                if res in repacking_resids:
                    repacking_resids.remove(res)
            if args.design and design_around_mutant is not None:
                designing_resids = selector_to_vector(design_around_mutant,
                                                      mutant_pose)
                for res in lig_resids:
                    if res in designing_resids:
                        designing_resids.remove(res)
                design_around_mutant = designing_resids

        #If --remove_ligand is given, translate each listed chain far away
        #(500 A along a random direction) rather than deleting it outright
        if args.remove_ligand:
            used_xyzs = []
            for chain in args.remove_ligand.split(','):
                x, y, z = random_direction_selector(used_xyzs)
                used_xyzs.append(np.array([x, y, z]))
                trans_vec = pr.rosetta.numeric.xyzVector_double_t(x, y, z)
                trans = pr.rosetta.protocols.rigid.RigidBodyTransMover(
                    trans_vec)
                jump_id = pr.rosetta.core.pose.get_jump_id_from_chain(
                    chain, mutant_pose)
                trans.rb_jump(jump_id)
                trans.step_size(500.0)
                trans.apply(mutant_pose)

        tf = task_factory_builder(repacking_residues=repacking_resids, \
                designing_residues=design_around_mutant)

        move_map = build_move_map(True, True, True)

        #Turn on match (enzdes) constraints if needed
        if args.enzdes_constraint_file:
            if not args.remove_ligand:
                print_out('The code will break if the constraint file is ' +
                          'attached to the ligand that is being removed')
            apply_match_constraints(mutant_pose, args.enzdes_constraint_file)

        #Apply coordinate constraints unless they were turned off (default is on)
        if not args.no_constraints:
            mutant_pose = coord_constrain_pose(mutant_pose)

        #Repack or minimize if selected
        if args.fast_relax:
            print_out("Running With Fast Relax")
            new_pose = fast_relax_mutant(mutant_pose, tf, move_map, sf)
        else:
            print_out("Running repack and minimize")
            new_pose = repack_and_minimize_mutant(mutant_pose, tf, move_map,
                                                  sf)

        #Output file name: original pdb name, residue number and uaa,
        #nstruct number, and whether the ligand is included
        base_pdb_filename = args.input_pdb.split('/')[-1].split(
            '_')[0].replace('.pdb', '')
        outname = '{}_{}_{}_{}_{}.pdb'.format(base_pdb_filename, \
                args.residue_number, uaa, out_ligand_file, str(struct))
        if args.residue_number == '0':
            outname = '{}_{}_min.pdb'.format(base_pdb_filename, \
                    out_ligand_file)
        if args.out_suffix:
            outname = outname.replace(".pdb", args.out_suffix + ".pdb")
        print_out("Outputting protein file as : " + outname)

        out_dir = out_directory(args.out_directory)
        outname = '/'.join([out_dir, outname])
        print_out("Writing Outputting protein file as : " + outname)

        #Symmetry needs to be broken to load the proteins in another
        #script to analyze them symmetrically.
        if args.symmdef_file:
            full_out = outname.replace('.pdb', 'sym.pdb')
            new_pose.dump_pdb(full_out)
            not_symm = NotResidueSelector(pre_symm_ris)
            remove_symmetrized_aa = DeleteRegionMover()
            remove_symmetrized_aa.set_residue_selector(not_symm)
            remove_symmetrized_aa.apply(new_pose)
            new_pose.dump_pdb(outname)
        else:
            new_pose.dump_pdb(outname)
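
Example #2 relies on several helpers (build_move_map, selector_to_vector, ligand_neighbor_selection, task_factory_builder, and others) that are not part of the snippet. Below are minimal sketches of two of them, inferred only from how they are called above; the bodies are assumptions, not the original implementations.

#Hypothetical sketches of two helpers used above, inferred from their call sites
import pyrosetta as pr
from pyrosetta.rosetta.core.select import get_residues_from_subset

def build_move_map(bb, chi, jump):
    #Assumed: the three booleans toggle backbone, side-chain, and jump movement
    move_map = pr.rosetta.core.kinematics.MoveMap()
    move_map.set_bb(bb)
    move_map.set_chi(chi)
    move_map.set_jump(jump)
    return move_map

def selector_to_vector(residue_selector, pose):
    #Assumed: converts a residue selector into PDB-numbered strings such as '105A',
    #matching the ids built with pose2pdb() elsewhere in these scripts
    resids = []
    for i in get_residues_from_subset(residue_selector.apply(pose)):
        resids.append(''.join(pose.pdb_info().pose2pdb(i).split()))
    return resids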
Example #3
def main(args):

    params = []
    if args.ligand_type == 'ligand':
        params.append(args.ligand_name)
        lig_name = args.ligand_name.split('/')[-1].replace('.params', '')
    if args.unnatural:
        params.append(args.cis_params_file)
        params.append(args.trans_params_file)

    #setting up constraint file conditionals, as well as other init arguments
    cstfiles = {}
    if args.enzdes_cstfiles:
        manage_cst_files(cstfiles, args.enzdes_cstfiles)

    init_args = ' '.join(['-run:preserve_header', '-extra_res_fa'] + params)

    #The 'yeslig' and 'nolig' init_args either contain or don't contain
    #the constraint files for the pdbs.

    if not args.previously_run_files:
        #Starting up rosetta with the appropriate params files -
        #Need to incorporate constraints somehow
        print_out("emd182::Starting rosetta with the following parameters: " \
                    + init_args)
        pr.init(init_args)

        #Setting up score functions
        sf = pr.rosetta.core.scoring.ScoreFunction()
        sfcst = pr.rosetta.core.scoring.ScoreFunction()
        sf.add_weights_from_file('ref2015')
        sfcst.add_weights_from_file('ref2015_cst')

        #Looking in the directory to determine what to group:
        pdb_list = glob('/'.join([args.source_directory, '*.pdb']))

        #Briefly setting up the wt poses
        wt_pose = {}
        for pdb_file_name in args.wt_pdb.split(','):
            print(pdb_file_name)
            pdb_list.remove(pdb_file_name)
            if 'yeslig' in pdb_file_name:
                #pr.init(init_args['yeslig'])
                wt_pose['yeslig'] = pr.pose_from_pdb(pdb_file_name)
            elif 'nolig' in pdb_file_name:
                #pr.init(init_args['yeslig'])
                wt_pose['nolig'] = pr.pose_from_pdb(pdb_file_name)
            else:
                print('Error: submitted wild type pdbs do not have "yeslig" or "nolig" ' + \
                        'in their names. Please submit properly named pdbs.')
                sys.exit()

        for wt_key in wt_pose.keys():
            sf(wt_pose[wt_key])

        print_out("Loaded and scored the wt_pose")

        #Residue selection for later analysis
        if args.ligand_type == 'protein':
            ligand = ResidueIndexSelector(args.residue_set)
        elif args.ligand_type == 'ligand':
            ligand = ResidueNameSelector()
            ligand.set_residue_name3(lig_name)

        pdb_dict = {}
        #Dictionary of pdb info - name, resi, uaa, cis/trans, ligand, nstruct
        group_set = {}
        #Dictionary of [residue number] : pdb file name for analysis

        print_out("Interpreting the list of pdbs in the directory: " + \
                    args.source_directory )
        #Sorting file names into groups and into dictionaries
        for pdb_file_name in pdb_list:
            pdb_dict[pdb_file_name] = file_interpreter_to_dict(pdb_file_name)
            resi = pdb_file_name.split('/')[-1].replace('.pdb', '').split('_')[1]
            if_list_key_not_in_dict(resi, group_set)
            group_set[resi].append(pdb_file_name)

        scored_set = {}  #Dictionary of scores at each residue index
        print_out('Taking groups of proteins and scoring them one by one.')
        #Taking protein groups and scoring them one by one.
        for residue_key in group_set.keys():
            print_out("Now on residue number: " + str(residue_key))
            cst_key = residue_key + "_cst"
            #Place poses in a dictionary with the key designated as the pdb_file_name.
            #Each item in a dictionary is a scored pose object.
            #Creates a new dictionary each time to allow for score analysis each round.
            resi_poses = dictionaried_poses(group_set[residue_key], sf,
                                            cstfiles)
            scored_set[residue_key] = scoring_poses(resi_poses, pdb_dict, sf)
            if cstfiles:
                scored_set[cst_key] = scoring_poses(resi_poses, pdb_dict,
                                                    sfcst)
            for filename in resi_poses.keys():
                if pdb_dict[filename]['contain_ligand']:
                    interaction_energy_addition(resi_poses[filename], ligand, \
                        sf, scored_set[residue_key][filename], \
                        pdb_dict[filename]['cisortrans'])
                    if cstfiles:
                        interaction_energy_addition(resi_poses[filename], ligand, \
                            sfcst, scored_set[cst_key][filename], \
                            pdb_dict[filename]['cisortrans'])

        #Generating wt_data for the two wt poses.
        wt_data = {}
        wt_data['total_True'] = total_energy(wt_pose['yeslig'], sf)
        wt_data['sasa_True'] = total_sasa(wt_pose['yeslig'])
        interaction_energy_addition(wt_pose['yeslig'], ligand, sf, wt_data,
                                    'True')

        prem = PerResidueEnergyMetric()
        prem.set_scorefunction(sf)
        e_vals = prem.calculate(wt_pose['yeslig'])
        wt_data['per_res_True'] = [ e_vals[i] for i in \
                range(1, wt_pose['yeslig'].total_residue() + 1) ]

        wt_data['total_False'] = total_energy(wt_pose['nolig'], sf)
        wt_data['sasa_False'] = total_sasa(wt_pose['nolig'])
        e_vals = prem.calculate(wt_pose['nolig'])
        wt_data['per_res_False'] = [ e_vals[i] for i in \
                range(1, wt_pose['nolig'].total_residue() + 1) ]

        if cstfiles:
            wt_data['cst_total_True'] = total_energy(wt_pose['yeslig'], sfcst)
            interaction_energy_addition(wt_pose['yeslig'], ligand, \
                        sfcst, wt_data, 'True')
            wt_data['cst_total_False'] = total_energy(wt_pose['nolig'], sfcst)

            prem.set_scorefunction(sfcst)
            e_vals = prem.calculate(wt_pose['yeslig'])
            wt_data['cst_per_res_True'] = [ e_vals[i] for i in \
                    range(1, wt_pose['yeslig'].total_residue() + 1) ]
            e_vals = prem.calculate(wt_pose['nolig'])
            wt_data['cst_per_res_False'] = [ e_vals[i] for i in \
                    range(1, wt_pose['nolig'].total_residue() + 1) ]

        if args.out_file:
            print_out("Storing all the data into a pickle file")
            with open(args.out_file, 'wb') as savedfile:
                pickle.dump([scored_set, pdb_dict, group_set, wt_data],
                            savedfile)
            print_out("Pickle has stored files into " + str(args.out_file))

    if args.previously_run_files:
        print_out("Loading data pickled earlier from this code: " +
                  args.previously_run_files)
        with open(args.previously_run_files, 'rb') as loading_file:
            scored_set, pdb_dict, group_set, wt_data = pickle.load(
                loading_file)
        print_out("Loaded data into scored_set, pdb_dict, and group_set")
    """
    At this point, object -- scored_set -- is a dictionary of dictionaries that 
    contains all of the scored values - total score, sasa, and interaction score
    Goes by -- scored_set[residue_key][filename]
                -- this dictionary contains multiple types of scores,
                    - sasa, total, and inter_E
                -- for each filename.
    The object -- pdb_dict[any pdb filename] contains parsed information about the pdbs
            -'pdb', 'mut_res', 'uaa', 'cis/trans', 'contain_ligand', and 'n'
    The object -- group_set -- contains one set of information:
            -'resi' for the residue index, which contains a list of pdb file names.
    """

    ###Data analysis time -
    collated = {}
    value_keys = []
    #First will be making graphs of each pdb structure on a per residue basis
    #print(scored_set.keys())
    #print(wt_data.keys())
    out_directory(args.out_directory)
    if args.per_resi:
        for resikey in scored_set.keys():
            out_dir = '/'.join([args.out_directory, resikey])
            out_directory(out_dir)
            for prokey, scores in scored_set[resikey].items():
                pdb_to_png = prokey.split('/')[-1].replace('.pdb', '.png')
                png_name = '/'.join([out_dir, pdb_to_png])
                if 'yeslig' in prokey:
                    ref = 'per_res_True'
                else:
                    ref = 'per_res_False'
                if 'cst' in resikey:
                    ref = 'cst_' + ref

                plot_res_ddg(np.array(wt_data[ref]), \
                            np.array(scored_set[resikey][prokey]['per_resi']), \
                            name=png_name)

    #Remove the per_resi keys so they do not interfere with the averaging below
    for resikey in scored_set.keys():
        for prokey in scored_set[resikey].keys():
            scored_set[resikey][prokey].pop('per_resi', None)

    #First function parses the scored_sets of data and sorts into this format:
    #collated[residue_number][scoretype] = [list of scores]
    for resikey in scored_set:
        keysets = {}
        for filename in scored_set[resikey]:
            for k, v in scored_set[resikey][filename].items():
                if_list_key_not_in_dict(k, keysets)
                keysets[k].append(v)
                if k not in value_keys:
                    value_keys.append(k)
        collated[resikey] = keysets
    skip_keys = ['per_resi', 'cst_per_resi']
    averages = {}
    for residue, dictionary in collated.items():
        averages[residue] = {}
        for k, v in dictionary.items():
            if k in skip_keys:
                continue
            averages[residue][k] = sum(v) / len(v)

    print(averages.keys())
    sorted_residues = sorted( set( [int(x.split('_')[0][0:-1]) \
                                    for x in averages.keys() ] ) )
    sorted_cstkeys = [x for x in averages.keys() if 'cst' in x]
    sorted_cstkeys = sorted(sorted_cstkeys,
                            key=lambda x: float(x.split('_')[0][:-1]))
    sorted_reskeys = [x for x in averages.keys() if 'cst' not in x]
    sorted_reskeys = sorted(sorted_reskeys, key=lambda x: float(x[:-1]))
    print(sorted_reskeys)
    print(len(sorted_reskeys))
    print(sorted_cstkeys)
    print(len(sorted_cstkeys))
    print(sorted_residues)
    print(len(sorted_residues))

    #should be a list of numbers
    #Keytypes should be - total_C_True total_T_True sasa_C_True sasa_T_True
    #Keytypes should be - total_C_False total_T_False sasa_C_False sasa_T_False
    #Keytypes should be - interE_C interE_T
    out_file_name_pdb = args.wt_pdb.split("/")[-1].split('_')[0]
    #First thing to do is organize data by residue

    #First set of graphs are normalized against the wt.
    #sorted_resi
    for key_type in value_keys:
        keybreak = key_type.split('_')
        wtkey = ''.join([keybreak[0], "_", keybreak[-1]])
        if keybreak[0] == 'interE':
            wtkey = ''.join([keybreak[0], "_True"])
        xlabel = 'REU'
        if 'sasa' in key_type:
            continue
        title = ' '.join(
            ["Normalized",
             str(key_type), 'across', out_file_name_pdb])
        make_bar_plot([averages[x][key_type] for x in sorted_reskeys], \
                sorted_residues, "Residue Index", xlabel, title, \
                control=wt_data[wtkey], out_dir=args.out_directory)
        if cstfiles:
            title = ' '.join(["Normalized", str(key_type), 'across', \
                                out_file_name_pdb, '- with csts'])
            make_bar_plot([averages[x][key_type] for x in sorted_cstkeys], \
                    sorted_residues, "Residue Index", xlabel, title, \
                    control=wt_data[wtkey], out_dir=args.out_directory)
            title = ' '.join(["Normalized", str(key_type), ' constraints for', \
                                out_file_name_pdb])
            make_bar_plot([averages[x+'_cst'][key_type] - averages[x][key_type] \
                    for x in sorted_reskeys], sorted_residues, "Residue Index", \
                    xlabel, title, control=wt_data[wtkey], out_dir=args.out_directory)

    #Second set of graphs are comparing Cis to Trans States
    for key in ['total']:
        xlabel = 'REU'
        if key == 'sasa':
            continue
        for tf in [False, True]:
            ckey = '_'.join([key, 'C', str(tf)])
            tkey = '_'.join([key, 'T', str(tf)])
            title = ' '.join(["Cis minus Trans", str(key), 'in', out_file_name_pdb, \
                    'with ligand',str(tf)])
            data_points = np.asarray([averages[x][ckey] - averages[x][tkey] \
                            for x in sorted_reskeys])
            cst_points = np.asarray([averages[x][ckey] - averages[x][tkey] \
                            for x in sorted_cstkeys])
            if not tf:
                high_mask = data_points >= 2.0
                low_mask = data_points <= -2.0
                make_bar_plot(data_points, sorted_residues, "Residue Index", xlabel,\
                title, out_dir=args.out_directory)
                if cstfiles:
                    title = ' '.join(["Cis minus Trans", str(key), 'in', \
                        out_file_name_pdb,'with ligand',str(tf),'with constraints'])
                    make_bar_plot(cst_points, sorted_residues, "Residue Index", \
                        xlabel, title, out_dir=args.out_directory)

            else:
                print(ckey)
                print(tf)
                make_bar_plot(data_points, sorted_residues, "Residue Index", xlabel,\
                    title, out_dir=args.out_directory, upper_mask=high_mask, \
                    lower_mask=low_mask, legend_info={'red':'nolig cis preferred', \
                    'green':'nolig trans preferred','blue':'nolig no preference'})

                if cstfiles:
                    title = ' '.join(["Cis minus Trans", str(key), 'in', \
                        out_file_name_pdb,'with ligand',str(tf),'with constraints'])
                    make_bar_plot(cst_points, sorted_residues, "Residue Index", \
                        xlabel, title, out_dir=args.out_directory, \
                        upper_mask=high_mask, lower_mask=low_mask, \
                        legend_info={'red':'nolig cis preferred', \
                        'green':'nolig trans preferred','blue':'nolig no preference'})

    key = 'interE'
    xlabel = 'REU'
    ckey = '_'.join([key, 'C'])
    tkey = '_'.join([key, 'T'])
    title = ' '.join(["Cis vs Trans", str('interE'), 'in', \
                    out_file_name_pdb, 'with ligand',str(tf)])
    make_bar_plot([averages[x][ckey] - averages[x][tkey] for x in sorted_reskeys], \
                sorted_residues, "Residue Index", xlabel, title, \
                out_dir=args.out_directory)

    if cstfiles:
        title = ' '.join(["Cis vs Trans", str('interE'), 'in', \
                    out_file_name_pdb, 'with ligand',str(tf),'and csts'])
        make_bar_plot([averages[x][ckey] - averages[x][tkey] for x in sorted_cstkeys],\
                    sorted_residues, "Residue Index", xlabel, title, \
                    out_dir=args.out_directory)

    #Finding good data for generating scatterplots of apo vs bound states

    print_out("Setting up data for the scatterplot")
    partial_keysets = [
        '_'.join(x.split('_')[0:2]) for x in value_keys if len(x.split('_')) >= 2
    ]
    print_out("Set of score key types")
    print(value_keys)
    print_out("WT data keys")
    print(wt_data.keys())
    for key in value_keys:
        if 'inter' in key:
            partial_keysets.append(key)

    partial_keysets = list(set(partial_keysets))
    for tf in ['True', 'False']:
        remove_residues = []
        for resi in sorted_reskeys:
            #Assigning values for the necessary checks
            tot_T_f = averages[resi]['total_T_' + tf]
            tot_C_f = averages[resi]['total_C_' + tf]
            wt_f = wt_data['total_' + tf]
            if abs(tot_T_f - wt_f) > args.value_limits and \
               abs(tot_C_f - wt_f) > args.value_limits:
                remove_residues.append(resi)


    #Remove models where both the cis and trans models are bad in the apo state
    #(trims sorted_cstkeys, sorted_reskeys, and sorted_residues)
    print_out('removing these residues:')
    print_out(remove_residues)
    for resi in set(remove_residues):
        sorted_residues.remove(int(resi[:-1]))
        sorted_cstkeys.remove(resi + '_cst')
        sorted_reskeys.remove(resi)
    print(sorted_residues)
    print(len(sorted_residues))
    print(sorted_reskeys)
    print(len(sorted_reskeys))
    print(sorted_cstkeys)
    print(len(sorted_cstkeys))

    x20 = [-25, 25]
    x20 = [-50, 50]
    x25 = [-15, 50]
    x5 = [-15, 15]
    x0 = [0, 0]
    apobound = {}
    print(sorted_residues)
    print(partial_keysets)
    print(averages['22A'].keys())
    for key_type in partial_keysets:
        tf = [key_type.split('_')[0] + "_" + x for x in ['False', 'True']]
        if 'inter' in key_type or 'sasa' in key_type:
            continue
        print_out('wt key - ' + str(tf))
        key_false = key_type + '_False'
        key_true = key_type + '_True'
        apo = [averages[x][key_false] - wt_data[tf[0]] for x in sorted_reskeys]
        apobound[key_false] = apo
        bound = [
            averages[x][key_true] - wt_data[tf[1]] for x in sorted_reskeys
        ]
        apobound[key_true] = bound
        title = key_type + " apo vs bound"
        make_scatterplot(np.asarray(apo), np.asarray(bound), 'apo REU', 'bound REU', \
                    title, outdir=args.out_directory, tick_labels=sorted_reskeys,\
                    xylimits=[x25,x25], lines=[[x20,x20]])
        if cstfiles:
            apo = [
                averages[x][key_false] - wt_data[tf[0]] for x in sorted_cstkeys
            ]
            apobound['cst_' + key_false] = apo
            bound = [
                averages[x][key_true] - wt_data[tf[1]] for x in sorted_cstkeys
            ]
            apobound['cst_' + key_true] = bound
            title = key_type + " apo vs bound - with constraints"
            make_scatterplot(np.asarray(apo), np.asarray(bound), 'apo REU', \
                    'bound REU', title, outdir=args.out_directory, \
                    tick_labels=sorted_reskeys, xylimits=[x25,x25], lines=[[x20,x20]])

    #apobound holds the per-residue data normalized against the matching wt_data entries
    #major_trans = [apobound['total_T_True'][resi] - apobound['total_C_False'][resi] \
    #            for resi in range(0,len(apobound['total_T_True']))]
    #major_cis = [apobound['total_C_True'][resi] - apobound['total_T_False'][resi] \
    #            for resi in range(0,len(apobound['total_C_True']))]
    #title = 'Cis vs Trans states-- bound-apo energies'
    spaces = ' ' * 10 + 'vs' + ' ' * 10
    #make_scatterplot(np.asarray(major_trans), np.asarray(major_cis), \
    #    'Trans bound' + spaces + 'Cis apo', 'Cis bound' + spaces + 'Trans apo', title, \
    #    outdir=args.out_directory, tick_labels=sorted_reskeys, \
    #    xylimits=[x20,x20], lines=[[x20,x0],[x0,x20]])

    apo = [apobound['total_C_False'][resi] - apobound['total_T_False'][resi] \
                for resi in range(0,len(apobound['total_C_False']))]
    bound = [apobound['total_C_True'][resi] - apobound['total_T_True'][resi] \
                for resi in range(0,len(apobound['total_C_True']))]
    title = 'Cis minus Trans for apo vs bound'
    make_scatterplot(np.asarray(apo), np.asarray(bound), 'Cis apo' + spaces + \
            'Trans apo', 'Cis bound' + spaces + 'Trans bound', title, \
            outdir=args.out_directory, tick_labels=sorted_residues, \
            xylimits=[x20,x20], lines=[[x20,x0]])
    title = 'Cis minus Trans for apo vs bound - zoom'
    make_scatterplot(np.asarray(apo), np.asarray(bound), 'Cis apo' + spaces + \
            'Trans apo', 'Cis bound' + spaces + 'Trans bound', title, \
            outdir=args.out_directory, tick_labels=sorted_residues, \
            xylimits=[x5,x5], lines=[[x5,x0]])
    if cstfiles:
        apo = [apobound['cst_total_C_False'][resi] - \
                    apobound['cst_total_T_False'][resi] \
                    for resi in range(0,len(apobound['cst_total_C_False']))]
        bound = [apobound['cst_total_C_True'][resi] - \
                    apobound['cst_total_T_True'][resi] \
                    for resi in range(0,len(apobound['cst_total_C_True']))]
        title = 'Cis minus Trans for apo vs bound with csts'
        make_scatterplot(np.asarray(apo), np.asarray(bound), 'Cis apo' + spaces + \
                'Trans apo', 'Cis bound' + spaces + 'Trans bound', title, \
                outdir=args.out_directory, tick_labels=sorted_residues, \
                xylimits=[x20,x20], lines=[[x20,x0]])
        title = 'Cis minus Trans for apo vs bound with csts - zoom'
        make_scatterplot(np.asarray(apo), np.asarray(bound), 'Cis apo' + spaces + \
                'Trans apo', 'Cis bound' + spaces + 'Trans bound', title, \
                outdir=args.out_directory, tick_labels=sorted_residues, \
                xylimits=[x5,x5], lines=[[x5,x0]])
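
The helpers total_energy and total_sasa used throughout Example #3 are also not included in the snippet. The sketches below assume they return the pose's total Rosetta energy and its total solvent-accessible surface area; they are illustrations, not the original code.

#Hypothetical sketches of the scoring helpers used in Example #3
import pyrosetta as pr

def total_energy(pose, score_function):
    #Assumed: total Rosetta energy (REU) of the pose under the given score function
    return score_function(pose)

def total_sasa(pose):
    #Assumed: total solvent-accessible surface area computed with SasaCalc
    sasa_calc = pr.rosetta.core.scoring.sasa.SasaCalc()
    sasa_calc.calculate(pose)
    return sasa_calc.get_total_sasa()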