def loop_weights(energies, weights, scales, offsets, dec1, inp_dir2, scoretype2, curr_weights, list_weight_discs):
    """Recursively try every combination of candidate weights and record metrics.

    The recursion depth is ``len(curr_weights)``.  While fewer weights than
    ``energies`` have been chosen, every value of ``weights[depth]`` is
    appended in turn (to a copy of ``curr_weights``) and the function recurses.
    Once one weight per energy term has been chosen, the second score
    directory is re-read with those weights, intersected and merged with
    ``dec1``, and converted to metrics by ``discparse.pdbs_dict_to_metrics``.

    Results are accumulated in place: ``list_weight_discs[i]`` receives the
    i-th chosen weight and ``list_weight_discs[-1]`` receives the metrics
    (presumably the caller supplies ``len(energies) + 1`` lists -- TODO confirm).

    Returns None.  Raises IndexError if ``weights`` has fewer candidate lists
    than the recursion needs.
    """
    depth = len(curr_weights)

    if depth != len(energies):
        # Not every energy term has a weight yet: branch on each candidate.
        for candidate in weights[depth]:
            extended = curr_weights[:]
            extended.append(candidate)
            loop_weights(energies, weights, scales, offsets, dec1, inp_dir2,
                         scoretype2, extended, list_weight_discs)
        return

    # Complete assignment: rescore directory 2 with these weights.
    dec2 = scorefileparse.read_dir(inp_dir2, scoretype2, list_energies=energies,
                                   weights=curr_weights, scales=scales, offsets=offsets)
    dec_inter1, dec_inter2 = scorefileparse.pdbs_intersect([dec1, dec2])
    merged = scorefileparse.merge_pdbs_dicts([dec_inter1, dec_inter2])
    metrics = discparse.pdbs_dict_to_metrics(merged)

    for slot, chosen in enumerate(curr_weights):
        list_weight_discs[slot].append(chosen)
    list_weight_discs[-1].append(metrics)
def main(args):
    """Scatter-plot each PDB's energy gap against its discrimination score.

    ``args`` is indexed positionally: ``args[1]`` is the input directory and
    ``args[2]`` the score type (presumably ``sys.argv`` -- TODO confirm).

    Decoys and natives are read and normalised, and the discrimination data
    is read from the same directory.  The three are then intersected by PDB.
    For every shared PDB the energy gap is

        min(native value[0] for natives whose value[1] < 2.0)
        - min(decoy value[0])

    (value[1] is presumably RMSD -- TODO confirm).  The gap is plotted against
    ``disc_inter[pdb][0]``, with a regression line, one panel per label.  The
    figure is saved via ``conv.save_fig`` with base name ``<inp_dir>/test.txt``
    and suffix ``disc_v_egap``.

    Raises ValueError (from ``min``) if a PDB has no native entry passing the
    ``< 2.0`` filter.
    """
    # read in and rename arguments
    inp_dir = args[1]
    scoretype = args[2]

    dec, nat = scorefileparse.read_dec_nat(inp_dir, [], scoretype)

    disc = discparse.read_dir(inp_dir)

    dec_norm = scorefileparse.norm_pdbs(dec)
    nat_norm = scorefileparse.norm_pdbs(nat, dec)

    [dec_inter, nat_inter, disc_inter] = scorefileparse.pdbs_intersect([dec_norm, nat_norm, disc])

    # Alternative label sets that were tried:
    # labels = ["Average","1.0","1.5","2.0","2.5","3.0","4.0","6.0"]
    labels = ["Average"]
    energy_gap = [[] for _ in labels]
    avg_disc = [[] for _ in labels]

    for pdb in dec_inter:
        # These values do not depend on the label, so compute them once per PDB.
        lowest_dec = min(e[0] for e in dec_inter[pdb].values())
        lowest_nat = min(n[0] for n in nat_inter[pdb].values() if n[1] < 2.0)
        gap = lowest_nat - lowest_dec
        pdb_disc = disc_inter[pdb][0]

        for gaps, discs in zip(energy_gap, avg_disc):
            gaps.append(gap)
            discs.append(pdb_disc)

    fig, axarr = conv.create_ax(len(labels), 1)

    for x_ind, label in enumerate(labels):
        ax = axarr[0, x_ind]

        scatterplot.draw_actual_plot(ax, avg_disc[x_ind], energy_gap[x_ind], [], label, "Disc", "Energy Gap")

        scatterplot.plot_regression(ax, avg_disc[x_ind], energy_gap[x_ind], False, False)

    filename = inp_dir + "/test.txt"

    conv.save_fig(fig, filename, "disc_v_egap", len(labels) * 3, 4)
def main(args):
    """Plot per-PDB comparisons between two scoring runs.

    ``args`` must provide ``input_dir_1``, ``input_dir_2``, ``scoretype1``,
    ``scoretype2`` and ``rmsd_cutoff`` attributes (presumably an argparse
    namespace -- TODO confirm).

    Both runs are read, filtered with ``rmsd_cutoff``, normalised and
    intersected.  For each shared PDB, in sorted order, ``plot`` draws the
    intersected data into column 0 of a two-column grid and the
    ``filter_norm_pdbs`` data into column 1.

    NOTE(review): the figure is never saved or returned -- confirm whether a
    ``conv.save_fig`` call is missing.
    """
    # read in and rename arguments
    title1 = os.path.basename(args.input_dir_1)
    title2 = os.path.basename(args.input_dir_2)

    d1, n1 = scorefileparse.read_dec_nat(args.input_dir_1, [], args.scoretype1, True)
    d2, n2 = scorefileparse.read_dec_nat(args.input_dir_2, [], args.scoretype2, True)

    dec1 = scorefileparse.filter_pdbs_by_rmsd(d1, args.rmsd_cutoff)
    nat1 = scorefileparse.filter_pdbs_by_rmsd(n1, args.rmsd_cutoff)
    dec2 = scorefileparse.filter_pdbs_by_rmsd(d2, args.rmsd_cutoff)
    nat2 = scorefileparse.filter_pdbs_by_rmsd(n2, args.rmsd_cutoff)

    dec_norm1 = scorefileparse.norm_pdbs(dec1)
    nat_norm1 = scorefileparse.norm_pdbs(nat1, dec1)
    dec_norm2 = scorefileparse.norm_pdbs(dec2)
    nat_norm2 = scorefileparse.norm_pdbs(nat2, dec2)

    [dec_inter1, nat_inter1, dec_inter2, nat_inter2] = scorefileparse.pdbs_intersect(
        [dec_norm1, nat_norm1, dec_norm2, nat_norm2])
    [dec_inter1, dec_inter2] = scorefileparse.pdbs_scores_intersect([dec_inter1, dec_inter2])
    [nat_inter1, nat_inter2] = scorefileparse.pdbs_scores_intersect([nat_inter1, nat_inter2])

    # NOTE(review): these start from the normalised sets *before* the PDB-level
    # intersection above, so a PDB in dec_inter1 may be absent here -- confirm.
    dec_filt1 = scorefileparse.filter_norm_pdbs(dec_norm1)
    nat_filt1 = scorefileparse.filter_norm_pdbs(nat_norm1)
    dec_filt2 = scorefileparse.filter_norm_pdbs(dec_norm2)
    nat_filt2 = scorefileparse.filter_norm_pdbs(nat_norm2)

    [dec_finter1, dec_finter2] = scorefileparse.pdbs_scores_intersect([dec_filt1, dec_filt2])
    [nat_finter1, nat_finter2] = scorefileparse.pdbs_scores_intersect([nat_filt1, nat_filt2])

    fig, axarr = conv.create_ax(2, len(dec_inter1))

    # Indentation here was previously a tab/space mix, which made the file
    # unparseable (IndentationError on Python 2, TabError on Python 3).
    for x_ind, pdb in enumerate(sorted(dec_inter1.keys())):

        ax = axarr[x_ind, 0]

        plot(dec_inter1, dec_inter2, nat_inter1, nat_inter2, ax, pdb, title1, title2)

        ax = axarr[x_ind, 1]

        plot(dec_finter1, dec_finter2, nat_finter1, nat_finter2, ax, pdb, title1, title2)
def main(input_dir_1, scoretype1, input_dir_2, scoretype2, rmsd_cutoff, output_pre):
    """Compare two scoring runs via the best-model RMSD of several selections.

    Parameters
    ----------
    input_dir_1, input_dir_2 : str
        Score directories of the two runs.  Their basenames are used as plot
        titles and to build the output file name.
    scoretype1, scoretype2
        Score types passed to ``scorefileparse.read_dec_nat``.
    rmsd_cutoff
        Threshold passed to ``scorefileparse.filter_pdbs_by_rmsd``.
    output_pre : str
        Output directory.  Figures are written through ``conv.save_fig``
        with the base name ``<output_pre>/<title1>_<title2>.txt`` and the
        suffixes ``_line_R``, ``_line_A``, ``_distdeltas`` and ``_scattdeltas``.

    For each shared PDB (sorted), ``plot_pareto`` must return a mapping from
    selection label to ``(rank1, rank2, rmsd)``.  The labels "All", "Amber",
    "Rosetta", "ParetoR1".."ParetoR10" and "ParetoA1".."ParetoA10" are read
    back for the line plots, so a missing label raises KeyError.
    """
    # read in and rename arguments
    title1 = os.path.basename(input_dir_1)
    title2 = os.path.basename(input_dir_2)

    d1, n1 = scorefileparse.read_dec_nat(input_dir_1, scoretype1, repl_orig=False)
    d2, n2 = scorefileparse.read_dec_nat(input_dir_2, scoretype2, repl_orig=False)

    dec1 = scorefileparse.filter_pdbs_by_rmsd(d1, rmsd_cutoff)
    nat1 = scorefileparse.filter_pdbs_by_rmsd(n1, rmsd_cutoff)
    dec2 = scorefileparse.filter_pdbs_by_rmsd(d2, rmsd_cutoff)
    nat2 = scorefileparse.filter_pdbs_by_rmsd(n2, rmsd_cutoff)

    dec_norm1 = scorefileparse.norm_pdbs(dec1)
    nat_norm1 = scorefileparse.norm_pdbs(nat1, dec1)
    dec_norm2 = scorefileparse.norm_pdbs(dec2)
    nat_norm2 = scorefileparse.norm_pdbs(nat2, dec2)

    [dec_inter1, nat_inter1, dec_inter2, nat_inter2] = scorefileparse.pdbs_intersect(
        [dec_norm1, nat_norm1, dec_norm2, nat_norm2])
    [dec_inter1, dec_inter2] = scorefileparse.pdbs_scores_intersect([dec_inter1, dec_inter2])
    [nat_inter1, nat_inter2] = scorefileparse.pdbs_scores_intersect([nat_inter1, nat_inter2])

    # NOTE(review): the filtered/intersected sets below are not used anywhere
    # in this function.  The calls are kept in case the helpers have side
    # effects; remove them once confirmed pure.
    dec_filt1 = scorefileparse.filter_norm_pdbs(dec_norm1)
    nat_filt1 = scorefileparse.filter_norm_pdbs(nat_norm1)
    dec_filt2 = scorefileparse.filter_norm_pdbs(dec_norm2)
    nat_filt2 = scorefileparse.filter_norm_pdbs(nat_norm2)

    [dec_finter1, dec_finter2] = scorefileparse.pdbs_scores_intersect([dec_filt1, dec_filt2])
    [nat_finter1, nat_finter2] = scorefileparse.pdbs_scores_intersect([nat_filt1, nat_filt2])

    fig, axarr = conv.create_ax(2, len(dec_inter1))

    # label -> ([pdb, ...], [rmsd, ...]), appended in sorted-PDB order
    line_plot_data = {}
    # pdb -> {label: rmsd}
    min_naive_by_pdb = {}

    for x_ind, pdb in enumerate(sorted(dec_inter1.keys())):
        plot_r_v_r(dec_inter1, dec_inter2, nat_inter1, nat_inter2, axarr[x_ind, 0], pdb, title1, title2)

        min_naive = plot_pareto(dec_inter1, dec_inter2, nat_inter1, nat_inter2, axarr[x_ind, 1], pdb, title1, title2)

        for key, (_rank1, _rank2, rmsd) in min_naive.items():
            pdbs, rmsds = line_plot_data.setdefault(key, ([], []))
            pdbs.append(pdb)
            rmsds.append(rmsd)
            min_naive_by_pdb.setdefault(pdb, {})[key] = rmsd

    # Rank PDBs by the RMSD reached by the "All" selection (stable sort), then
    # reorder every label's series into that common PDB order.
    all_pdbs, all_rmsds = line_plot_data["All"]
    order = sorted(range(len(all_rmsds)), key=lambda idx: all_rmsds[idx])
    ranked_pdbs_by_rmsd_all = {all_pdbs[idx]: rank for rank, idx in enumerate(order)}

    for label in list(line_plot_data):
        pdbs, rmsds = line_plot_data[label]
        # Result is a pair of tuples: (pdbs, rmsds) in "All"-rank order.
        line_plot_data[label] = tuple(zip(*sorted(
            zip(pdbs, rmsds), key=lambda pair: ranked_pdbs_by_rmsd_all[pair[0]])))

    filename = output_pre + "/" + title1 + "_" + title2 + ".txt"

    # NOTE(review): ``fig`` (r-vs-r and Pareto panels) is drawn but never
    # saved; the original save was disabled:
    # conv.save_fig(fig, filename, "rmsd_v_rmsd_{0}".format(rmsd_cutoff), 7, len(dec_inter1)*3)

    # Line plots: one figure per Pareto family ("R" / "A").  Panel i overlays
    # the first i+1 series, so each panel adds one more selection.
    for initial in ["R", "A"]:
        ordered_labels = ["All", "Amber", "Rosetta"]
        ordered_labels.extend("Pareto{0}{1}".format(initial, i) for i in range(1, 11))

        lines = [(line_plot_data[label][0], line_plot_data[label][1], label) for label in ordered_labels]

        fig2, axarr2 = conv.create_ax(1, len(ordered_labels), shx=True, shy=True)

        for i in range(len(ordered_labels)):
            line.plot_series(axarr2[i, 0], lines[0:i + 1], "RMSD vs. pdb", "PDB", "RMSD", linestyle='')
            conv.add_legend(axarr2[i, 0])

        conv.save_fig(fig2, filename, "_line_{0}".format(initial), 10, len(ordered_labels) * 5)

    # Histogram plots: (top, bottom) label pairs compared by gen_dist_plot.
    hist_comp = [("Amber", "All"), ("Rosetta", "All"), ("ParetoR10", "All"), ("ParetoA10", "All")]
    for initial in ("R", "A"):
        for baseline in ("Rosetta", "Amber"):
            hist_comp.extend(("Pareto{0}{1}".format(initial, ind), baseline) for ind in range(1, 11))

    fig3, axarr3 = conv.create_ax(2, len(hist_comp), shx=False, shy=False)

    for ind, (top, bottom) in enumerate(hist_comp):
        gen_dist_plot(axarr3[ind, 0], axarr3[ind, 1], top, bottom, min_naive_by_pdb)

    conv.save_fig(fig3, filename, "_distdeltas", 7, len(hist_comp) * 5, tight=False)

    # Scatter plots: row 0 = ParetoR1..10, row 1 = ParetoA1..10.
    fig4, axarr4 = conv.create_ax(10, 2)
    for i in range(1, 11):
        gen_scatterplot(axarr4[0, i - 1], "ParetoR{0}".format(i), "Rosetta", "Amber", min_naive_by_pdb)
        gen_scatterplot(axarr4[1, i - 1], "ParetoA{0}".format(i), "Rosetta", "Amber", min_naive_by_pdb)

    conv.save_fig(fig4, filename, "_scattdeltas", 30, 6)
def main(list_input_dirs, energies_names, output_pre):
    """Scan a grid of relative weights for two runs and plot discrimination.

    Parameters
    ----------
    list_input_dirs : sequence
        ``list_input_dirs[0]`` and ``list_input_dirs[1]`` are
        ``(input_dir, scoretype)`` pairs for the two runs.
    energies_names : sequence or mapping
        Either a sequence of ``[scoretype, energy_name, ...]`` rows, or
        anything directly indexable by scoretype (e.g. a dict).  The result
        for each run is passed to ``scorefileparse.read_dec_nat``.
    output_pre : str
        Prefix handed to ``conv.save_fig`` (suffix ``_weights_v_disc``).

    For every PDB shared by both runs, run 1's and run 2's decoy scores are
    weighted by ``2**w1`` and ``2**w2`` for ``w1, w2`` in ``-10, -8, ..., 8``,
    merged, and scored with ``disc.given_data_run_disc``.  One log2-log2
    scatter panel per PDB colours each weight pair by that value.  A final
    "All" panel shows the mean over PDBs.
    """
    # read in and rename arguments
    inp_dir1 = list_input_dirs[0][0]
    scoretype1 = list_input_dirs[0][1]
    inp_dir2 = list_input_dirs[1][0]
    scoretype2 = list_input_dirs[1][1]

    column_dict = {}
    for c in energies_names:
        column_dict[c[0]] = c[1:]

    def _energy_names(scoretype):
        # Previously energies_names was indexed by scoretype directly, which
        # raises TypeError for a list of [scoretype, names...] rows and left
        # column_dict unused.  Prefer the row table; fall back to the old
        # direct indexing for inputs that supported it (e.g. a dict).
        if scoretype in column_dict:
            return column_dict[scoretype]
        return energies_names[scoretype]

    dec1, nat1 = scorefileparse.read_dec_nat(inp_dir1, _energy_names(scoretype1), scoretype1)
    dec2, nat2 = scorefileparse.read_dec_nat(inp_dir2, _energy_names(scoretype2), scoretype2)

    [dec_inter1, nat_inter1, dec_inter2, nat_inter2] = scorefileparse.pdbs_intersect([dec1, nat1, dec2, nat2])

    n_pdbs = len(dec_inter1)
    weight_exponents = range(-10, 10, 2)  # weights 2**-10 .. 2**8
    disc_divs = (1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 6.0)

    sum_discs = Counter()

    fig, axarr = conv.create_ax(1, n_pdbs + 1, True, True)

    for x_ind, pdb in enumerate(sorted(dec_inter1.keys())):

        # (weight_1, weight_2) -> discrimination value for this PDB
        discs_per_pdb = {}

        for w_1 in weight_exponents:
            for w_2 in weight_exponents:
                weight_1 = 2 ** w_1
                weight_2 = 2 ** w_2
                weighted_1 = scorefileparse.weight_dict(dec_inter1[pdb], weight_1)
                weighted_2 = scorefileparse.weight_dict(dec_inter2[pdb], weight_2)
                merged = scorefileparse.merge_dicts([weighted_1, weighted_2])
                ddata1 = scorefileparse.convert_disc(merged)

                # A fresh list per call, as before, so an in-place change by
                # the callee cannot leak between grid points.
                disc1, _, _ = disc.given_data_run_disc(ddata1, True, list(disc_divs))
                discs_per_pdb[(weight_1, weight_2)] = disc1

        ax = axarr[x_ind, 0]

        # NOTE: ``basex``/``basey`` were renamed ``base`` in Matplotlib 3.3
        # and removed in 3.5; kept for the Matplotlib version this targets.
        ax.set_xscale('log', basex=2)
        ax.set_yscale('log', basey=2)

        keys = sorted(discs_per_pdb)
        x = [w1 for w1, _ in keys]
        y = [w2 for _, w2 in keys]
        d = [discs_per_pdb[k] for k in keys]

        s = scatterplot.draw_actual_plot(ax, x, y, d, pdb, scoretype1, scoretype2, 'bwr')
        fig.colorbar(s, ax=ax)
        scatterplot.add_x_y_line(ax, 0, 600)

        sum_discs.update(discs_per_pdb)

    # Final panel: mean discrimination per weight pair over all PDBs.
    ax = axarr[n_pdbs, 0]

    keys = sorted(sum_discs)
    x = [w1 for w1, _ in keys]
    y = [w2 for _, w2 in keys]
    # float() keeps this a true mean under Python 2 integer division.
    d = [sum_discs[k] / float(n_pdbs) for k in keys]

    ax.set_xscale('log', basex=2)
    ax.set_yscale('log', basey=2)

    s = scatterplot.draw_actual_plot(ax, x, y, d, "All", scoretype1, scoretype2, cm='bwr')
    fig.colorbar(s, ax=ax)
    scatterplot.add_x_y_line(ax, 0, 600)

    conv.save_fig(fig, output_pre, "_weights_v_disc", 3, n_pdbs * 3)