def side_by_side(args):
    """Plot the MP and RS runs side by side: RMSD vs. total energy.

    For each of the two runs, reads a score file and an RMSD table, merges
    them on 'description', then draws two scatter subplots annotated with a
    short numeric alias per model.

    Args:
        args: dict with keys 'mp_sc', 'mp_pc', 'rs_sc', 'rs_pc' naming the
              score files and RMSD tables of the two runs.
    """
    mp_sc = Rf.score_file2df(args['mp_sc'])
    mp_pc = get_rmsds_from_table(args['mp_pc'])
    mp_sc = mp_sc.merge(mp_pc, on='description')

    rs_sc = Rf.score_file2df(args['rs_sc'])
    rs_pc = get_rmsds_from_table(args['rs_pc'])
    rs_sc = rs_sc.merge(rs_pc, on='description')

    # Short numeric aliases for the long model names: MP models keep their
    # enumeration index, everything else is offset by 100.
    # NOTE(review): indices from the two loops can collide, so two distinct
    # models may share an alias — confirm this is acceptable.
    names_dict = {}
    for i, d in enumerate(mp_sc['description'].values):
        names_dict[d] = i if 'MP' in d else i + 100
    for i, d in enumerate(rs_sc['description'].values):
        if d not in names_dict:
            names_dict[d] = i if 'MP' in d else i + 100
    for k, v in names_dict.items():
        print(v, k)

    # (The former scatter `label=<array of descriptions>` was dropped: scatter
    # takes a single legend string and no legend was ever drawn.)
    axmp = plt.subplot(121)
    axmp.scatter(mp_sc['pc_rmsd'].values, mp_sc['tot_mp_fa'].values)
    for x, y, d in zip(mp_sc['pc_rmsd'].values, mp_sc['tot_mp_fa'],
                       mp_sc['description']):
        axmp.annotate(names_dict[d], (x, y))
    # Axes objects use set_title(); the commented-out axmp.title('MP') was
    # the wrong API (AttributeError), which is why the titles were disabled.
    axmp.set_title('MP')

    axrs = plt.subplot(122)
    axrs.scatter(rs_sc['pc_rmsd'].values, rs_sc['tot_rs_fa'].values)
    for x, y, d in zip(rs_sc['pc_rmsd'].values, rs_sc['tot_rs_fa'],
                       rs_sc['description']):
        axrs.annotate(names_dict[d], (x, y))
    axrs.set_title('RS')

    plt.show()
def side_by_side(args):
    """Scatter-plot the MP and RS runs next to each other (RMSD vs. energy)."""
    # Load score file + RMSD table for each run and join them on model name.
    frames = {}
    for tag in ('mp', 'rs'):
        scores = Rf.score_file2df(args['%s_sc' % tag])
        rmsds = get_rmsds_from_table(args['%s_pc' % tag])
        frames[tag] = scores.merge(rmsds, on='description').copy()
    mp_sc = frames['mp']
    rs_sc = frames['rs']

    # Give every model a short numeric alias (MP models keep their index,
    # all others are shifted by 100) and echo the mapping to stdout.
    names_dict = {}
    for idx, desc in enumerate(mp_sc['description'].values):
        names_dict[desc] = idx if 'MP' in desc else idx + 100
    for idx, desc in enumerate(rs_sc['description'].values):
        if desc not in names_dict.keys():
            names_dict[desc] = idx if 'MP' in desc else idx + 100
    for desc, alias in names_dict.items():
        print(alias, desc)

    axmp = plt.subplot(121)
    axmp.scatter(mp_sc['pc_rmsd'].values, mp_sc['tot_mp_fa'].values,
                 label=mp_sc['description'].values)
    for x, y, desc in zip(mp_sc['pc_rmsd'].values, mp_sc['tot_mp_fa'],
                          mp_sc['description']):
        axmp.annotate(names_dict[desc], (x, y))

    axrs = plt.subplot(122)
    axrs.scatter(rs_sc['pc_rmsd'].values, rs_sc['tot_rs_fa'].values,
                 label=rs_sc['description'].values)
    for x, y, desc in zip(rs_sc['pc_rmsd'].values, rs_sc['tot_rs_fa'],
                          rs_sc['description']):
        axrs.annotate(names_dict[desc], (x, y))

    plt.show()
def main():
    """CLI entry point: parse arguments and dispatch by -mode to one of the
    analysis/plotting routines in this file."""

    def str2bool(v):
        # argparse `type=bool` is a trap: bool('False') is True, so ANY
        # non-empty string on the command line parsed as True.  Parse the
        # text explicitly instead (backward compatible for `-best True`).
        return str(v).lower() in ('1', 'true', 't', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('-mode', default='q')
    parser.add_argument('-pc_old')
    parser.add_argument('-score_old')
    parser.add_argument('-obj_old')
    parser.add_argument('-pc_new')
    parser.add_argument('-score_new')
    parser.add_argument('-obj_new')
    parser.add_argument('-pc_wj')
    parser.add_argument('-score_wj')
    parser.add_argument('-obj_wj')
    parser.add_argument('-sc')
    parser.add_argument('-mp_sc')
    parser.add_argument('-rs_sc')
    parser.add_argument('-pc')
    parser.add_argument('-mp_pc')
    parser.add_argument('-rs_pc')
    parser.add_argument('-span_threshold', default=0.3, type=float)
    parser.add_argument('-names', default=None, help='file list of names to show')
    parser.add_argument('-log_path', default='./', help='path to place log file')
    parser.add_argument('-percent', type=int, default=100)
    parser.add_argument('-best', type=str2bool, default=False)
    parser.add_argument('-terms', nargs='+', default=['score', 'a_shape', 'a_pack', 'a_ddg', 'res_solv'])
    parser.add_argument('-threshold', type=int, default=5)
    parser.add_argument('-show', default='show')

    args = vars(parser.parse_args())
    # NOTE(review): 'logeer_' looks like a typo for 'logger_', but the file
    # name may be relied upon by other tooling, so it is left unchanged.
    args['logger'] = Logger('logeer_%s.log' % time.strftime("%d.%0-m"), args['log_path'])

    if args['mode'] == 'old':
        analyse_old(args)
    elif args['mode'] == 'new':
        analyse_new(args)
    elif args['mode'] == 'wj':
        analyse_wj(args)
    elif args['mode'] == 'q':
        quick_rmsd_total(args)
    elif args['mode'] == 'slider':
        slide_ddg(args)
    elif args['mode'] == 's_by_s':
        side_by_side(args)
    elif args['mode'] == 'test':
        sc_df = Rf.score_file2df(args['sc'])
        # Result was previously bound to an unused variable; kept only as a
        # smoke call.
        Rf.get_best_of_best(sc_df)
    else:
        print('no mode')
    args['logger'].close()
# ----- 示例 #4 (Example 4): pasted-snippet separator; stray "0" artifact removed -----
def compare_designs(args):
    """Pie-chart the amino-acid composition of the 10 designs with the best
    'a_tms_aa_comp' (plus the parent design named in the score-file name),
    ordered by their frequency distance from the natural-TM distribution.

    Expected args keys: 'fasta', 'sc', 'show'.
    """
    all_seqs = read_seqs(args['fasta'], '.pdb')
    sc_df = Rf.score_file2df(args['sc'])
    best_by_aa_freq_df = Rf.get_best_num_by_term(sc_df, 10, 'a_tms_aa_comp')
    names_to_use = list(best_by_aa_freq_df['description'])
    # also keep the parent design, whose name is embedded in the score-file
    # name as "all_<name>.<ext>"
    names_to_use += [args['sc'].split('.')[0].split('all_')[1]]
    all_seqs = {k: v for k, v in all_seqs.items() if k in names_to_use}

    # per-design AA frequencies plus each design's distance from the natural
    # TM frequency distribution (wt_aa_freq)
    seqs_aa_freqs = {n: {} for n in all_seqs.keys() if n in names_to_use}
    freq_dists = {}
    for n, s in all_seqs.items():
        seqs_aa_freqs[n] = s.all_aas_frequencies(clean_zeros=True)
        # freq_dists[n] = calc_freq_dist(seqs_aa_freqs[n], wt_aa_freq,
        # 'AGFILPSTVWY')
        freq_dists[n] = calc_freq_dist(seqs_aa_freqs[n], wt_aa_freq)

    # closest-to-natural first; grid is 4 pies per row
    freq_dists = OrderedDict(sorted(freq_dists.items(), key=lambda t: t[1]))
    fig = plt.figure(figsize=(18, 18))
    nrows = math.floor(len(all_seqs.keys()) / 4)
    nrows += 1 if len(all_seqs.keys()) % 4 > 0 else 0
    i = 0
    # for n, s in all_seqs.items():
    first = True
    for n, dist_freq in freq_dists.items():
        if first:
            # echo the design closest to the natural distribution
            print(n)
            first = False
        # NOTE(review): `s` is looked up but never used below.
        s = all_seqs[n]
        plt.subplot(nrows, 4, 1 + i)
        fracs = [a for a in seqs_aa_freqs[n].values()]
        labels = [a for a in seqs_aa_freqs[n].keys()]
        colors = [color_map[a] for a in labels]
        plt.pie(fracs, labels=labels, autopct='%1.1f%%', colors=colors)
        name = n.split('.pdb')[0]
        if name.count('poly') == 2:
            p_name = name.split('_poly')[0]
        else:
            p_name = name
        # NOTE(review): %-formatting one-row DataFrame slices (e.g.
        # sc_df[...]['score']) works only when exactly one row matches and
        # relies on Series->float coercion that newer pandas rejects —
        # confirm under the pandas version in use.
        if name in list(sc_df['description']):
            title = (
                '%s\nscore: %.0f ∆∆G: %.0f\nfreq_dist: %.2f\naa_comp %.2f' %
                (p_name, sc_df[sc_df['description'] == name]['score'],
                 sc_df[sc_df['description'] == name]['a_ddg'], dist_freq * 100,
                 best_by_aa_freq_df[best_by_aa_freq_df['description'] ==
                                    name]['a_tms_aa_comp']))
        else:
            title = 'freq_dist %.2f' % (dist_freq * 100)
        plt.title(title)

        i += 1

    if args['show'] == 'show':
        plt.show()
    else:
        plt.savefig('%s_aa_freq.png' % args['fasta'])
# ----- 示例 #5 (Example 5): pasted-snippet separator; stray "0" artifact removed -----
def meta_pie(args: dict):
    """Draw three pie charts comparing amino-acid frequencies: natural TMs,
    the original designs, and the mutated designs found under args['dir']."""

    def draw_panel(position, title, freqs):
        # One pie in a 1x3 grid; colors follow the global per-AA color_map.
        plt.subplot(1, 3, position)
        plt.title(title)
        labels = list(freqs.keys())
        plt.pie(list(freqs.values()),
                labels=labels,
                autopct='%1.1f%%',
                colors=[color_map[a] for a in labels])
        plt.axis('equal')

    # Accumulate length-weighted AA counts over the 10 best models (by
    # a_tms_aa_comp) of every fasta/score pair in the target directory.
    mut_aa_freqs = {aa: 0 for aa in aas}
    score_files = [a for a in os.listdir(args['dir']) if '.score' in a]
    for fasta_file in [a for a in os.listdir(args['dir']) if '.fasta' in a]:
        temp_seqs = read_seqs('%s/%s' % (args['dir'], fasta_file),
                              remove_suffix='.pdb')
        temp_sc = Rf.score_file2df('%s/%s.score' %
                                   (args['dir'], fasta_file.split('.')[0]))
        names = list(Rf.get_best_num_by_term(temp_sc, 10,
                                             'a_tms_aa_comp')['description'])
        for n, aaseq in temp_seqs.items():
            if n in names:
                for aa in aas:
                    mut_aa_freqs[aa] += aaseq.aa_frequency(aa) * len(aaseq)
    mut_aa_freqs_srt = OrderedDict(
        sorted(mut_aa_freqs.items(), key=lambda t: t[1]))

    # Same accumulation over the original (pre-mutation) designs.
    ori_fasta = ('/home/labs/fleishman/jonathaw/elazaridis/design/'
                 'polyA_13Nov/chosen_from_all_27Feb/pdbs/'
                 'all_dzns.fasta')
    ori_aa_freqs = {aa: 0 for aa in aas}
    for aaseq in read_seqs(ori_fasta).values():
        for aa in aas:
            ori_aa_freqs[aa] += aaseq.aa_frequency(aa) * len(aaseq)
    ori_aa_freqs_srt = OrderedDict(
        sorted(ori_aa_freqs.items(), key=lambda t: t[1]))

    plt.figure()
    draw_panel(1, 'natural TMs', wt_aa_freq)
    draw_panel(2, 'original designs', ori_aa_freqs_srt)
    draw_panel(3, 'mutated designs', mut_aa_freqs_srt)
    plt.show()
# ----- 示例 #6 (Example 6): pasted-snippet separator; stray "0" artifact removed -----
def draw_hbonds_profiles():
    """Two-stage h-bond profiler.

    Stage 1: submit one Rosetta hbond-scan job per residue to the cluster.
    Stage 2: collect the per-residue score files into all_score.sc and plot
    per-residue energy against membrane depth (z).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-pdb')
    parser.add_argument('-stage', type=int)
    args = vars(parser.parse_args())

    pdb = my.parse_PDB(args['pdb'])

    if args['stage'] == 1:
        seq_length = pdb.seq_length()
        # Shell loop submitting one job per residue index.
        command = "for i in `seq 1 %i`;do ~/bin/fleish_sub_general.sh /home/labs/fleishman/jonathaw/Rosetta/main/source/bin/rosetta_scripts.default.linuxgccrelease -parser:protocol ~/elazaridis/protocols/scan_hbonds.xml -s %s -mp:scoring:hbond -corrections::beta_nov15 -score:elec_memb_sig_die -score:memb_fa_sol -overwrite -out:prefix ${i}_ -script_vars energy_function=beta_nov15_elazaridis res_num=${i} s1=%i e1=%i ori1=%s s2=%i e2=%i ori2=%s ;done" % (seq_length, args['pdb'], 1, 24, 'out2in', 25, 48, 'out2in')
        print('issuing command\n%s' % command)
        os.system(command)

    if args['stage'] == 2:
        # Concatenate per-residue score files: one header line + all SCORE: rows.
        os.system("head -2 1_score.sc|tail -1 > all_score.sc")
        os.system("grep SCORE: *_score.sc|grep -v des >> all_score.sc")
        # res_id -> membrane z (the loop variable was `id` before, shadowing
        # the builtin).
        z_dict = {res_id: res.memb_z for res_id, res in pdb.res_items()}
        pos_dict = {v: k for k, v in z_dict.items()}
        sc_df = Rf.score_file2df('all_score.sc')
        zs, scs = [], []
        # Description starts with the residue index used as the job prefix.
        for d, sc in zip(sc_df['description'].values, sc_df['a_e_res']):
            zs.append(z_dict[int(d.split('_')[0])])
            scs.append(sc)
        # NOTE(review): zs may contain None (memb_z unset) and is still given
        # to scatter; only annotate filters it — confirm intended.
        plt.scatter(zs, scs)
        for z, sc in zip(zs, scs):
            if z is not None:
                plt.annotate(pos_dict[z], xy=(z, sc), xytext=(-20, 20),
                             textcoords='offset points', ha='right',
                             va='bottom',
                             bbox=dict(boxstyle='round,pad=0.5', fc='yellow',
                                       alpha=0.5),
                             arrowprops=dict(arrowstyle='->',
                                             connectionstyle='arc3,rad=0'))
        plt.show()
def slide_ddg(args):
    """Interactive model explorer: scatter of args['x'] vs args['y'] with
    per-term filter sliders, a color-by radio widget, and reset/print/log
    buttons.  State lives in module globals because the matplotlib widget
    callbacks (update/reset/print_table/log_table/colorfunc) read them.

    NOTE(review): requires args['x'] and args['y'], which the main() parser
    in this file does not define — confirm the caller supplies them.
    """
    global new_df, radio, color_by, picked
    global scat, ax, sliders, sc_df, fig, cm, cbar
    sc_df = Rf.score_file2df(args['sc'], args['names'])
    args['logger'].log('score file has %i entries' % len(sc_df))
    if args['pc'] is not None:
        # Merge the RMSD table in; models missing from either side drop out.
        pc_df = get_rmsds_from_table(args['pc'])
        args['logger'].log('pc file had %i entries' % len(pc_df))
        a = sc_df.merge(pc_df, on='description')
        args['logger'].log('combined there are %i entries' % len(a))
        sc_df = a.copy()

    if args['percent'] != 100:
        # Keep only the best args['percent'] percent by the y-axis term.
        threshold = np.percentile(sc_df[args['y']], args['percent'])
        sc_df = sc_df[ sc_df[args['y']] < threshold ]

    color_by = args['y']
    picked = False

    new_df = sc_df.copy()
    fig, ax = plt.subplots()
    plt.subplots_adjust(left=0.25, bottom=0.25)

    cm = plt.cm.get_cmap('RdYlBu')

    scat = ax.scatter(sc_df[args['x']].values, sc_df[args['y']].values, s=40, cmap=cm, c=sc_df[color_by], picker=True)
    cbar = plt.colorbar(scat)
    # One filter slider per requested score term.
    sliders = {}
    for i, term in enumerate(args['terms']):
        slider_ax = plt.axes([0.25, 0.01+i*0.035, 0.65, 0.03])
        sliders[term] = Slider(slider_ax, term, np.min(sc_df[term].values), np.max(sc_df[term].values), 0)
        sliders[term].on_changed(update)

    ax.set_xlim(np.min(new_df[args['x']].values)-1, np.max(new_df[args['x']].values)+1)
    ax.set_ylim(np.min(new_df[args['y']].values)-1, np.max(new_df[args['y']].values)+1)

    ax.set_xlabel(args['x'])
    ax.set_ylabel(args['y'])

    resetax = plt.axes([0.025, 0.7, 0.15, 0.15]) #[0.8, 0.025, 0.1, 0.04])
    button = Button(resetax, 'Reset', color='lightgoldenrodyellow', hovercolor='0.975')
    button.on_clicked(reset)

    printax = plt.axes([0.025, 0.3, 0.15, 0.15])
    printbutton = Button(printax, 'Print', color='green', hovercolor='red')
    printbutton.on_clicked(print_table)

    logax = plt.axes([0.025, 0.1, 0.15, 0.15])
    logbutton = Button(logax, 'log table', color='blue', hovercolor='red')
    logbutton.on_clicked(log_table)

    # `axisbg` was removed in matplotlib 2.0; `facecolor` is its replacement.
    rax = plt.axes([0.025, 0.5, 0.15, 0.15], facecolor='white')
    radio = RadioButtons(rax, args['terms'], active=0)
    radio.on_clicked(colorfunc)

    # cbar = plt.colorbar(scat)
    pl = PointLabel(new_df, ax, fig, args['x'], args['y'], ['description', 'a_sasa', 'a_res_solv', 'a_pack', 'a_span_topo', 'a_ddg', 'fa_elec'], args['logger'])
    fig.canvas.mpl_connect('pick_event', pl.onpick)

    plt.show()
def remove_pdbs_only(pwd, name, time):
    """Delete every *.pdb.gz in *pwd* except those of the best models.

    Selection: among models passing the standard run filters, keep the best
    10 by ddg; if none pass, keep the best 5 of the failed set.

    Args:
        pwd:  directory holding the *.pdb.gz files (with trailing slash).
        name: job name used in the combined score-file name.
        time: date stamp used in the combined score-file name.
    """
    pdb_files = [a for a in os.listdir(pwd) if a[-7:] == '.pdb.gz']
    run_filters = DCRP.generate_run_filters(
        args={
            'ddg': 24.0,
            'sasa': 1400,
            'shape': 0.6,
            'packstat': 0.6,
            'buried_2': 3,
            'hbonds': -10.
        })
    score_dict = RF.score2dict('%sall_%s_%s.score' % (pwd, name, time))
    passed, failed = DCRP.all_who_pass_run_filters({}, score_dict, run_filters)
    if len(passed) != 0:
        print('there are %i passed scores, so choosing from there' %
              len(list(passed.keys())))
        best_structs = DCRP.best_n_structures(
            {
                'filter': 'ddg',
                'n': min([10, len(list(passed.keys()))])
            }, passed)
    else:
        print('there are no passed, so choosing from the failed')
        best_structs = DCRP.best_n_structures(
            {
                'filter': 'ddg',
                'n': min([5, len(list(failed.keys()))])
            }, failed)
    best_names = [a['description'] for a in best_structs]
    # A plain loop (not a throwaway list comprehension) for the deletion
    # side effect.
    for pdb in pdb_files:
        if pdb[:-7] not in best_names:
            os.remove(pdb)
def process_folder(args):
    """Post-process one finished run folder: combine per-job score files,
    digest error files, delete job/out clutter, and keep only the best
    models' pdb.gz files.

    Returns 'not finished' / 'no scores' on early exit, otherwise None.

    NOTE(review): file lists are taken from `pwd` (the cwd at call time)
    while the process chdirs into args['folder'] — confirm callers are
    already inside the folder being processed.
    """
    import time
    pwd = os.getcwd() + '/'
    os.chdir(args['folder'])

    name = pwd.split('/')[-2]
    # The original rebound the name `time` to this string, shadowing the
    # module imported above; a distinct name keeps the module usable.
    stamp = time.strftime("%d.%0-m")

    if not args['force_process']:
        if os.path.isfile('%sall_%s_%s.err' % (pwd, name, stamp)):
            print('found %sall_%s_%s.err, STOPPING' % (pwd, name, stamp))
            if args['remove_pdbs']:
                remove_pdbs_only(pwd, name, stamp)
                sys.exit()
            return 'not finished'
        if not is_folder_finished(pwd):
            return 'not finished'
    sc_files = [a for a in os.listdir(pwd) if a[-3:] == '.sc']
    pdb_files = [a for a in os.listdir(pwd) if a[-7:] == '.pdb.gz']
    err_files = [a for a in os.listdir(pwd) if a[:4] == 'err.']
    job_files = [a for a in os.listdir(pwd) if a[:4] == 'job.']
    print('found a total of %i job files, %i err files, %i pdbs and %i scores' % (len(job_files), len(err_files),
                                                                                  len(pdb_files), len(sc_files)))
    if len(sc_files) == 0:
        return 'no scores'
    combine_scores('%sall_%s_%s.score' % (pwd, name, stamp), sc_files)
    non_triv_errs = process_errors('%sall_%s_%s.err' % (pwd, name, stamp), err_files)

    if non_triv_errs == 0:
        # All errors were trivial: clear the batch-system clutter.
        print('removing out.* and job.*')
        for out in os.listdir(pwd):
            if out[:4] == 'out.':
                os.remove(out)
        for job in job_files:
            os.remove(job)
        try:
            os.remove('./command')
        except OSError:
            # Missing file is the only expected failure; the former bare
            # `except` also hid real errors.
            pass

    run_filters = DCRP.generate_run_filters(args={'ddg': 24.0, 'sasa': 1400, 'shape': 0.6, 'packstat': 0.6,
                                                  'buried_2': 3, 'hbonds': -10.})
    score_dict = RF.score2dict('%sall_%s_%s.score' % (pwd, name, stamp))
    passed, failed = DCRP.all_who_pass_run_filters({}, score_dict, run_filters)
    if len(passed) != 0:
        print('there are %i passed scores, so choosing from there' % len(list(passed.keys())))
        best_structs = DCRP.best_n_structures({'filter': 'ddg', 'n': min([10, len(list(passed.keys()))])}, passed)
    else:
        print('there are no passed, so choosing from the failed')
        best_structs = DCRP.best_n_structures({'filter': 'ddg', 'n': min([10, len(list(failed.keys()))])}, failed)
    best_names = [a['description'] for a in best_structs]
    print('the best:', best_names)
    print('removing all other pdbs')
    for pdb in pdb_files:
        if pdb[:-7] not in best_names:
            os.remove(pdb)
    os.chdir(pwd)
# ----- 示例 #10 (Example 10): pasted-snippet separator; stray "0" artifact removed -----
def get_dielectric_from_rosetta():
    """Read distance (d), depth (z) and fa_elec from the combined score file.

    Model names look like "<d>_<z>...", so both values are parsed from the
    description column.

    Returns:
        (ds, zs, es): three parallel lists of per-model values.
    """
    df = rf.score_file2df('scores/all_scores.score')
    ds = [float(a.split('_')[0]) for a in df['description'].values]
    zs = [float(a.split('_')[1]) for a in df['description'].values]
    es = df['fa_elec'].values

    # The original filtering loop was disabled (`if 1:` — condition
    # `d % 1 == 0 or d % 1 == 0.5` commented out), so it merely copied the
    # data.  Return equivalent list copies without the dead loop; re-enable
    # a filter here if a d/z subset is ever needed again.
    return list(ds), list(zs), list(es)
def get_dielectric_from_rosetta():
    """Extract d and z (parsed from "<d>_<z>..." model names) plus fa_elec."""
    df = rf.score_file2df('scores/all_scores.score')
    name_parts = [a.split('_') for a in df['description'].values]
    ds = [float(p[0]) for p in name_parts]
    zs = [float(p[1]) for p in name_parts]
    es = df['fa_elec'].values

    ds_, zs_, es_ = [], [], []
    for d, z, e in zip(ds, zs, es):
        # Filter currently disabled; original condition kept for reference:
        # d % 1 == 0 or d % 1 == 0.5
        if 1:
            ds_.append(d)
            zs_.append(z)
            es_.append(e)
    return ds_, zs_, es_
def erbb2_mutants(args):
    """
    Draw what Rosetta thinks about Assaf's mutants of ErbB2: the computed
    ddG difference vs. wild type against the experimental values, one point
    (annotated "<wt><pos><mut>") per mutant.
    """
    score_dir = "/home/labs/fleishman/jonathaw/elazaridis/fold_and_dock/" + \
        "erbb2/mutations/all_results/"
    exp_table = "/home/labs/fleishman/jonathaw/elazaridis/fold_and_dock/" + \
        "erbb2/mutations/general_data/mut_table.txt"
    exp_df = parse_erbb2_exp_table(exp_table)
    wt_score_file = score_dir + "all_erbb2v4_wt_28Feb.score"
    wt_df = Rf.score_file2df(wt_score_file)
    wt_ddg = Rf.get_term_by_threshold(wt_df, 'score', 5, 'a_ddg', 'mean')

    exp_df['rosetta'] = np.nan
    for sc_file in [a for a in os.listdir(score_dir)
                    if '.score' in a and 'wt' not in a]:
        df = Rf.score_file2df(score_dir+sc_file)
        ddg = Rf.get_term_by_threshold(df, 'score', 5, 'a_ddg', 'mean')
        # Mutant name like "Y652A" is the third '_'-field of the file name.
        name = sc_file.split('_')[2]
        wt = name[0]
        pos = int(name[1:-1])
        mut = name[-1]
        # DataFrame.set_value() took a single index label (and was removed in
        # pandas 1.0); assigning through a boolean mask requires .loc.
        mask = ((exp_df['pos'] == pos) & (exp_df['wt'] == wt) &
                (exp_df['mut'] == mut))
        exp_df.loc[mask, 'rosetta'] = ddg - wt_ddg
    print(exp_df)
    exp_df = exp_df.dropna()
    print(exp_df.to_string())
    plt.scatter(exp_df['rosetta'], exp_df['exp'])
    plt.ylabel('experimental ∆∆G')
    plt.xlabel('rosetta ∆∆G')
    plt.axhline(0)
    plt.axvline(0)
    for i, row in exp_df.iterrows():
        plt.annotate('%s%i%s' % (row['wt'], row['pos'], row['mut']),
                     (row['rosetta'], row['exp']))
    plt.show()
def remove_pdbs_only(pwd, name, time):
    """Delete every *.pdb.gz in *pwd* except the best models' files
    (top 10 by ddg among filter-passing models, else top 5 of the failed)."""
    pdb_files = [a for a in os.listdir(pwd) if a[-7:] == '.pdb.gz']
    run_filters = DCRP.generate_run_filters(args={'ddg': 24.0, 'sasa': 1400, 'shape': 0.6, 'packstat': 0.6,
                                                  'buried_2': 3, 'hbonds': -10.})
    score_dict = RF.score2dict('%sall_%s_%s.score' % (pwd, name, time))
    passed, failed = DCRP.all_who_pass_run_filters({}, score_dict, run_filters)
    # Choose the candidate pool and how many models survive.
    if len(passed) != 0:
        print('there are %i passed scores, so choosing from there' % len(list(passed.keys())))
        pool, cap = passed, 10
    else:
        print('there are no passed, so choosing from the failed')
        pool, cap = failed, 5
    best_structs = DCRP.best_n_structures(
        {'filter': 'ddg', 'n': min([cap, len(list(pool.keys()))])}, pool)
    best_names = [struct['description'] for struct in best_structs]
    for pdb in pdb_files:
        if pdb[:-7] not in best_names:
            os.remove(pdb)
def read_all_scores(file_name: str, args: dict = None) -> pd.DataFrame:
    """
    If necessary gather scores/*.sc into a single score file, then load it
    into a DataFrame with aa1/aa2/d/z parsed from the description column
    ("<aa1>_<aa2>_<d>_<z>...").
    """
    # Avoid the shared-mutable-default trap (`args=dict()`); behavior for
    # callers passing args explicitly is unchanged.
    if args is None:
        args = {}
    if not os.path.exists(args["sc"]):
        # `.logger(` was a typo for `.log(` (the method used below).
        args["logger"].log("no score file found. gathering scores to %s" % args["sc"])
        # Header first, then every SCORE: line.  The original had the grep
        # source and redirect target swapped, writing the header INTO the
        # first .sc file instead of into args['sc'].
        first_sc = [a for a in os.listdir("scores/") if ".sc" in a][0]
        os.system("grep description scores/%s > %s" % (first_sc, args["sc"]))
        os.system("grep SCORE: scores/*.sc | grep -v description >> %s" % args["sc"])
    df = rf.score_file2df(file_name)
    args["logger"].log("found %i entries in score file" % len(df))
    desc_spl = df["description"].str.split("_")
    df["aa1"] = desc_spl.str.get(0)
    df["aa2"] = desc_spl.str.get(1)
    df["d"] = desc_spl.str.get(2).astype(np.float64)
    df["z"] = desc_spl.str.get(3).astype(np.float64)
    df = df.round({"d": 2, "z": 2})
    return df
def slide_ddg(args):
    """Score-vs-RMSD scatter with an interactive ddG slider.  Uses module
    globals consumed by the module-level `update` callback."""
    global scat, ax, ddg_slider, sc_df, fig
    sc_df = Rf.score_file2df(args['sc'], args['names'])
    rmsd_df = get_rmsds_from_table(args['pc'])
    # Join scores with RMSDs on the model name.
    sc_df = sc_df.merge(rmsd_df, on='description').copy()

    fig, ax = plt.subplots()
    plt.subplots_adjust(left=0.25, bottom=0.25)

    scat = ax.scatter(sc_df['pc_rmsd'].values, sc_df['score'].values)

    # Slider spans [min(a_ddg), +5], starting at 0.
    slider_ax = plt.axes([0.25, 0.1, 0.65, 0.03])
    ddg_slider = Slider(slider_ax, 'ddG', np.min(sc_df['a_ddg'].values), +5, 0)
    ddg_slider.on_changed(update)
    plt.show()
def read_all_scores(file_name: str, args=dict()) -> pd.DataFrame:
    """Load a score file and split its description column ("aa1_aa2_d_z...")
    into aa1/aa2/d/z columns, rounding d and z to 2 decimals."""
    # (A grep-based score-gathering step existed here once; it is
    # intentionally disabled.)
    df = rf.score_file2df(file_name)
    args['logger'].log('found %i entries in score file' % len(df))
    parts = df['description'].str.split('_')
    df['aa1'] = parts.str.get(0)
    df['aa2'] = parts.str.get(1)
    df['d'] = parts.str.get(2).astype(np.float64)
    df['z'] = parts.str.get(3).astype(np.float64)
    return df.round({'d': 2, 'z': 2})
def slide_ddg(args):
    """Interactive score vs. RMSD scatter plot with a ddG slider widget."""
    global scat, ax, ddg_slider, sc_df, fig

    # Merge score file and RMSD table on the model name.
    merged = Rf.score_file2df(args['sc'], args['names']).merge(
        get_rmsds_from_table(args['pc']), on='description')
    sc_df = merged.copy()

    fig, ax = plt.subplots()
    plt.subplots_adjust(left=0.25, bottom=0.25)
    scat = ax.scatter(sc_df['pc_rmsd'].values, sc_df['score'].values)

    # Slider from the lowest observed ddG up to +5, initial value 0; the
    # module-level `update` callback redraws on change.
    ddg_slider = Slider(plt.axes([0.25, 0.1, 0.65, 0.03]), 'ddG',
                        np.min(sc_df['a_ddg'].values), +5, 0)
    ddg_slider.on_changed(update)
    plt.show()
def main():
    """Select the best-scoring models from a score file and report their names.

    Modes: '%' keeps models past a percentile of the chosen filter term
    ('-over_under' picks the side); 'num' keeps the top -num models.
    Names go to the log file, or to -result if given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-sc', type=str, help='score file')
    parser.add_argument('-percent', type=float, default=5, help='percent (1-100) best scoring to get')
    parser.add_argument('-filter', type=str, default='score', help='filter or score term to use')
    parser.add_argument('-num', default=10, type=int, help='use if you want a number of results, not better than percentile')
    parser.add_argument('-mode', default='%')
    parser.add_argument('-over_under', type=str, default='under', help='under/over score should be over/under threshold')
    parser.add_argument('-result', type=str, default=None, help='should the names be written to a file separate from the log file')
    args = vars(parser.parse_args())

    logger = Logger('top_%.1f_%s.log' % (args['percent'], args['filter']))

    # read in the score file, determine the threshold for the percentile
    sc_df = Rf.score_file2df(args['sc'])
    score = sc_df[args['filter']]

    pass_df = None
    if args['mode'] == '%':
        threshold = np.percentile(score, args['percent'])
        logger.log('found a threshold for %f for filter %s to be %.2f' % (args['percent'], args['filter'], threshold))

        # create a df for lines that pass the threshold, either over or above it...
        if args['over_under'] == 'over':
            pass_df = sc_df[sc_df[args['filter']] >= threshold]
        elif args['over_under'] == 'under':
            pass_df = sc_df[sc_df[args['filter']] <= threshold]

    if args['mode'] == 'num':
        sc_df.sort_values(args['filter'], inplace=True)
        pass_df = sc_df.head(args['num'])

    # An unrecognised -mode or -over_under previously fell through to a
    # NameError on pass_df; fail with a clear message instead.
    if pass_df is None:
        logger.close()
        raise ValueError('unknown -mode %r / -over_under %r' %
                         (args['mode'], args['over_under']))

    # output the names (description) of models that pass the threshold,
    # either to the logger file or to a separate result file
    if args['result'] is None:
        logger.create_header('models passing the threshold:')
        for name in pass_df['description']:
            logger.log(name, skip_stamp=True)
    else:
        with open(args['result'], 'w+') as fout:
            for name in pass_df['description']:
                fout.write(name + '\n')
    logger.close()
def CaCa_d_by_z(args: dict) -> None:
    """Compare Ca-Ca fa_elec energies on a (d, z) grid from three sources —
    the raw score file, the analytic model (rosetta_dz_model) and a parsed
    spline log — and plot either 1-d slices per z or 2-d contour maps.

    Expected args keys: 'sc', 'dslk_scores', 'plot_type' ('1ds'/'contourf'),
    'logger', plus whatever parse_rosetta_log/read_all_scores require.
    """
    # reading the Ca-Ca results from args['sc']; model names look like "<d>_<z>..."
    df = rf.score_file2df(args['sc'])
    desc = df['description'].str.split('_')
    df['d'] = desc.str.get(0).astype(np.float64)
    df['z'] = desc.str.get(1).astype(np.float64)

    # get data from spline log
    spline_log = parse_rosetta_log(args)
    spline_log_df = pd.DataFrame({
        'z': spline_log[0],
        'd': spline_log[1],
        'e': spline_log[2]
    })

    # get DSRK data as well
    dslk_df = read_all_scores(args['dslk_scores'], args)

    ds_sorted = np.round(sorted(list(set(df['d'].values))), 2)
    zs_sorted = np.round(sorted(list(set(df['z'].values))), 2)
    D, Z, E = [], [], []
    model_E = []
    spline_log_E = []
    # NOTE(review): these four lists are re-initialised below before first
    # use, so this assignment is dead.
    dre, dke, sre, ske = [], [], [], []
    for d in ds_sorted:
        # only distances below 10 A are collected
        if d >= 10:
            continue
        for z in zs_sorted:
            # if z > 5: continue
            D.append(d)
            Z.append(z)
            # +-0.01 window guards against float rounding of d; the /4
            # rescales fa_elec — TODO confirm the factor's origin
            E.append(df[(df['d'] >= d - 0.01) & (df['d'] <= d + 0.01) &
                        (df['z'] == z)]['fa_elec'].values[0] / 4)
            model_E.append(rosetta_dz_model(d, z))
            spline_log_E.append(
                spline_log_df[(spline_log_df['d'] >= d - 0.01)
                              & (spline_log_df['d'] <= d + 0.01) &
                              (spline_log_df['z'] == z)]['e'].values[0])

    # gather DSLK data in lists to make plots later
    # NOTE(review): dslk_D/dslk_Z/dre/dke/sre/ske are filled but never used
    # below, and sre/ske repeat the aa1=='D' filters of dre/dke — this looks
    # like a copy-paste slip (different aa pairs intended?); confirm before
    # relying on these lists.
    dslk_D, dslk_Z = [], []
    dre, dke, sre, ske = [], [], [], []
    for d in np.round(sorted(list(set(dslk_df['d'].values))), 2):
        for z in np.round(sorted(list(set(dslk_df['z'].values))), 2):
            dslk_D.append(d)
            dslk_Z.append(z)
            dre.append(dslk_df[(dslk_df['aa1'] == 'D')
                               & (dslk_df['aa2'] == 'R') & (dslk_df['d'] == d)
                               & (dslk_df['z'] == z)]['fa_elec'].values[0])
            dke.append(dslk_df[(dslk_df['aa1'] == 'D')
                               & (dslk_df['aa2'] == 'K') & (dslk_df['d'] == d)
                               & (dslk_df['z'] == z)]['fa_elec'].values[0])
            sre.append(dslk_df[(dslk_df['aa1'] == 'D')
                               & (dslk_df['aa2'] == 'R') & (dslk_df['d'] == d)
                               & (dslk_df['z'] == z)]['fa_elec'].values[0])
            ske.append(dslk_df[(dslk_df['aa1'] == 'D')
                               & (dslk_df['aa2'] == 'K') & (dslk_df['d'] == d)
                               & (dslk_df['z'] == z)]['fa_elec'].values[0])

    args['logger'].log('finished preparing lists')
    master_df = pd.DataFrame({
        'd': D,
        'z': Z,
        'e_caca': E,
        'model_e': model_E,
        'spline_log_e': spline_log_E
    })

    if args['plot_type'] == '1ds':
        # one subplot per (every second) z value: energy-vs-distance curves
        args['logger'].log('making 1d plots')
        fig = plt.figure()
        plt.subplots_adjust(left=None,
                            bottom=None,
                            right=None,
                            top=None,
                            wspace=0.15,
                            hspace=0.45)
        i = 0
        for i, z in enumerate(zs_sorted[::2]):
            plt.subplot(5, 4, 1 + i)
            d_df = master_df[master_df['z'] == z]
            plt.plot(d_df['d'], d_df['e_caca'].values, label='e_caca', c='r')
            plt.plot(d_df['d'], d_df['model_e'].values, label='model_e', c='b')
            plt.plot(d_df['d'],
                     d_df['spline_log_e'].values,
                     label='spline_log_e',
                     c='g')
            plt.title('z=%.2f' % z)
            plt.xlabel('d')
            plt.ylabel('e')
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.show()

    if args['plot_type'] == 'contourf':
        # interpolate the scattered (d, z, e) samples onto a grid and draw
        # filled contours: raw data on the left, analytic model on the right
        D_, Z_ = np.meshgrid(D, Z)

        levels = np.linspace(np.min(E), np.max(E), 1000)

        E_ = sc.interpolate.griddata((D, Z), E, (D_, Z_), method='cubic')
        model_E_ = sc.interpolate.griddata((D, Z),
                                           model_E, (D_, Z_),
                                           method='cubic')
        # NOTE(review): spline_log_E_ is interpolated but never plotted.
        spline_log_E_ = sc.interpolate.griddata((D, Z),
                                                spline_log_E, (D_, Z_),
                                                method='cubic')

        fig = plt.figure()
        plt.subplot(1, 2, 1)
        cs = plt.contourf(D_, Z_, E_, levels=levels, cmap=cm.coolwarm)
        plt.xlabel('d (A)')
        plt.ylabel('z (A)')
        fig.colorbar(cs, format="%.2f")

        plt.subplot(1, 2, 2)
        cs = plt.contourf(D_, Z_, model_E_, levels=levels, cmap=cm.coolwarm)
        plt.xlabel('d (A)')
        plt.ylabel('z (A)')
        fig.colorbar(cs, format="%.2f")
        plt.show()
def quick_rmsd_total(args):
    """Scatter model score against RMSD for models surviving a filter cascade.

    The score file (args['sc']) is merged with a pre-computed RMSD table
    (args['pc']) on 'description'.  Models are then selected either by
    Rf.get_best_of_best (when args['best'] is truthy) or by a fixed
    sequence of per-term thresholds, and the survivors are scattered as
    pc_rmsd vs. score, colored by a_hha when that column exists
    (a_span_topo otherwise).  A Z-score annotation and an RMSD threshold
    line are added, and the figure is shown or saved to PNG according to
    args['show'].

    Args:
        args: dict expected to provide 'sc', 'names', 'pc',
            'span_threshold', 'best', 'percent', 'terms', 'threshold',
            'show' and a 'logger' with a .log() method.
    """
    # The y axis is always the total Rosetta score.
    y_axis_term = 'score'

    sc_df = Rf.score_file2df(args['sc'], args['names'])
    args['logger'].log('found %i structs in sc_df' % len(sc_df))
    pc_df = get_rmsds_from_table(args['pc'])
    args['logger'].log('found %i structs in pc' % len(pc_df))
    # Inner merge: keep only models present in both tables.
    a = sc_df.merge(pc_df, on='description')
    sc_df = a.copy()

    args['logger'].log('left with %i in merged df' % len(sc_df))

    args['logger'].log('examining %s with span_topo threshold %f' % (args['sc'], args['span_threshold']))
    fig, ax = plt.subplots()

    if args['best']:
        sc_df = sc_df[sc_df['a_tms_span_fa'] > 0.5]
        threshold = np.percentile(sc_df[y_axis_term], args['percent'])
        sc_df = sc_df[sc_df[y_axis_term] < threshold]
        sc_df = sc_df[sc_df['a_span_topo'] >= 0.99]
        sc_df_pass = Rf.get_best_of_best(sc_df, args['terms'], args['threshold'])
        sc_df_fail = sc_df[~sc_df['description'].isin(sc_df_pass['description'])]
        args['logger'].log('%i models returned from BEST' % len(sc_df_pass))
    else:
        args['logger'].log('total of %i models in score' % len(sc_df))
        sc_df = sc_df[sc_df['a_tms_span_fa'] > 0.5]
        args['logger'].log('%i models pass tms_span' % len(sc_df))
        threshold = np.percentile(sc_df[y_axis_term], args['percent'])
        sc_df = sc_df[sc_df[y_axis_term] < threshold]
        args['logger'].log('for percent %.2f found threshold to be %.2f and %i strucutres pass it' % (args['percent'], threshold, len(sc_df)))
        sc_df = sc_df[sc_df['a_shape'] >= 0.6]
        sc_df = sc_df[sc_df['a_sasa'] > 700]
        args['logger'].log('%i passed sasa 600' % len(sc_df))
        sc_df = sc_df[sc_df['a_ddg'] < -5]
        args['logger'].log('%i passed ddg' % len(sc_df))
        sc_df = sc_df[sc_df['a_unsat'] < 1]
        args['logger'].log('%i passed unsat' % len(sc_df))
        sc_df['pass'] = sc_df['a_span_topo'] > args['span_threshold']
        sc_df = sc_df[sc_df['a_res_solv'] < -10]
        args['logger'].log('%i passed res_solv -10' % len(sc_df))

        sc_df_pass = sc_df[sc_df['a_span_topo'] > args['span_threshold']]
        args['logger'].log('%i models passed span_topo threshold' % len(sc_df_pass))
        sc_df_fail = sc_df[sc_df['a_span_topo'] <= args['span_threshold']]
        args['logger'].log('%i models failed span_topo threshold' % len(sc_df_fail))

    # NOTE(review): two unused np.ndarray(buffer=...) locals were removed
    # here; they raised "TypeError: buffer is too small" whenever
    # len(sc_df_pass) != len(sc_df), since the buffer came from sc_df_pass
    # while the shape came from sc_df.

    if 'a_hha' in sc_df.columns:
        ax.scatter(sc_df_pass['pc_rmsd'].values, sc_df_pass[y_axis_term].values, marker='o',
                c=sc_df_pass['a_hha'].values, picker=True, cmap=plt.cm.coolwarm)
    else:
        ax.scatter(sc_df_pass['pc_rmsd'].values, sc_df_pass[y_axis_term].values, marker='o',
                c=sc_df_pass['a_span_topo'].values, picker=True, cmap=plt.cm.coolwarm)

    # Pad the y limits by 1 unit around the passing models' score range.
    min_energy = np.nanmin(list(sc_df_pass[y_axis_term].values))
    max_energy = np.nanmax(list(sc_df_pass[y_axis_term].values))
    plt.ylim([min_energy - 1, max_energy + 1])
    plt.xlim([0, 15])
    plt.title(args['sc']+'_pass')

    z_score, rmsd_threshold = Rf.get_z_score_by_rmsd_percent(sc_df_pass)
    plt.text(0.75, 0.2, "Zscore=%.2f" % z_score, transform=ax.transAxes)
    plt.axvline(rmsd_threshold)

    # Interactive labels: clicking a point logs these columns for it.
    point_label_cols = list(set(args['terms'] + ['description', 'a_sasa', 'a_res_solv', 'a_pack', 'a_span_topo', 'a_ddg', 'fa_elec']))
    pl = PointLabel(sc_df_pass, ax, fig, 'pc_rmsd', y_axis_term, point_label_cols,
                    args['logger'])
    fig.canvas.mpl_connect('pick_event', pl.onpick)
    plt.xlabel('RMSD')
    plt.ylabel(y_axis_term)
    if args['show'] == 'show':
        plt.show()
    else:
        plt.savefig('%s.png' % args['sc'].split('.score')[0])
def design_fnd_scatter(args):
    """Scatter score vs. RMSD for designs, filtered against an original model.

    Builds per-term thresholds from the row of args['original_sc'] whose
    description contains args['name'] (80% of the original value, 50% for
    a_res_solv), drops models failing those thresholds via
    Rf.remove_failed_dict, then scatters pc_rmsd vs. score colored by
    a_span_topo.  With args['mode'] == 'all_fnds' the filtered DataFrame
    is returned instead of plotted.

    Raises:
        ValueError: if no description in the original score file contains
            args['name'].
    """
    sc_df = Rf.score_file2df(args['sc'], args['names'])
    args['logger'].log('found %i structs in sc_df' % len(sc_df))
    pc_df = get_rmsds_from_table(args['pc'])
    args['logger'].log('found %i structs in pc' % len(pc_df))
    a = sc_df.merge(pc_df, on='description')
    sc_df = a.copy()
    sc_df = sc_df[sc_df['a_tms_span'] > 0.5]

    threshold = np.percentile(sc_df['score'], args['percent'])
    sc_df = sc_df[sc_df['score'] < threshold]

    original_df = Rf.score_file2df(args['original_sc'])
    # Locate the original model's row by substring match on the description.
    row_name = None
    for d in original_df['description']:
        if args['name'] in d:
            row_name = d
    if row_name is None:
        # Previously this fell through to a NameError; fail explicitly.
        raise ValueError('no description in %s contains %s' %
                         (args['original_sc'], args['name']))
    original_row = original_df[original_df['description'] == row_name]
    # 'ou' says whether a design must score over or under the original.
    term_dict = {'total_score': {'ou': 'under'},
                 'a_sasa': {'ou': 'over'},
                 'a_pack': {'ou': 'over'},
                 'a_res_solv': {'ou': 'under'},
                 'a_ddg': {'ou': 'under'},
                 'a_span_topo': {'ou': 'over'}}
    for term in term_dict.keys():
        # a_res_solv gets a looser (50%) threshold than the other terms (80%).
        if term == 'a_res_solv':
            term_dict[term]['threshold'] = 0.5 * original_row[term].values[0]
        else:
            term_dict[term]['threshold'] = 0.8 * original_row[term].values[0]

    sc_df, fail_msg = Rf.remove_failed_dict(sc_df, term_dict)
    for k, v in fail_msg.items():
        print(v)

    if args['mode'] == 'all_fnds':
        return sc_df

    fig, ax = plt.subplots()
    ax.scatter(sc_df['pc_rmsd'].values, sc_df['score'].values, marker='o',
               c=sc_df['a_span_topo'].values, picker=True, cmap=plt.cm.coolwarm)

    min_energy = np.nanmin(list(sc_df['score'].values))
    max_energy = np.nanmax(list(sc_df['score'].values))
    plt.ylim([min_energy - 1, max_energy + 1])
    plt.xlim([0, 15])
    plt.title(args['name'])

    # Fixed: was `rf.get_z_score_by_rmsd_percent` (the rest of this function
    # uses the `Rf` alias) and `ax.transaxes` (the matplotlib attribute is
    # `transAxes`) -- both raised at runtime.
    z_score, rmsd_threshold = Rf.get_z_score_by_rmsd_percent(sc_df)
    plt.text(0.75, 0.2, "zscore=%.2f" % z_score, transform=ax.transAxes)
    plt.axvline(rmsd_threshold)

    point_label_cols = list(set(args['terms'] + ['description', 'a_sasa',
                                                 'a_res_solv', 'a_pack',
                                                 'a_span_topo', 'a_ddg',
                                                 'fa_elec']))
    pl = PointLabel(sc_df, ax, fig, 'pc_rmsd', 'score', point_label_cols,
                    args['logger'])
    fig.canvas.mpl_connect('pick_event', pl.onpick)

    if args['show'] == 'show':
        plt.show()
    else:
        plt.savefig('%s.png' % args['name'])
def mutant_table( args: dict ):
    """
    Correlate ResSolv and MPFrameWork ddG predictions with experimental
    results from both Doung 2006 and the dsTbL assay.

    For every mutant score file found under the ResSolv (args['dir']) and
    MPFrameWork result directories, the mean a_ddg of the 5% best-scoring
    models is taken, the matching wild-type value is subtracted, and the
    delta is written into main_df under the 'rs' or 'mp' column.  The
    table is then regressed against the experiments and plotted either as
    a 2x2 panel grid (args['all4']) or as two side-by-side dsTbL panels.
    """
    scores_dir = '/home/labs/fleishman/jonathaw/elazaridis/fold_and_dock/gpa/mutant_results/%s' % args['dir']
    mp_dir = '/home/labs/fleishman/jonathaw/elazaridis/fold_and_dock/gpa/mutant_results/mpframework_18Dec/'
    # raw string so \s+ is not treated as a (deprecated) string escape
    main_df = pd.read_csv("/home/labs/fleishman/jonathaw/elazaridis/" +
                          "fold_and_dock/gpa/mutant_results/" +
                          "experimental_results.tsv", sep=r'\s+')
    wt_beta_score_file = [a for a in os.listdir(scores_dir)
                          if 'wt' in a and '.score' in a][0]
    wt_beta_df = Rf.score_file2df(scores_dir + '/' + wt_beta_score_file)
    wt_beta_ddg = Rf.get_term_by_threshold(wt_beta_df, 'score', 5, 'a_ddg',
                                           'mean')
    wt_mp_df = Rf.score_file2df('%sall_gpav1_wt_mpframework_25Oct.score' % mp_dir)
    wt_mp_ddg = Rf.get_term_by_threshold(wt_mp_df, 'score', 5, 'a_ddg', 'mean')
    results = {'rs': {}, 'mp': {}}

    for sc_file in [a for a in os.listdir(scores_dir)+os.listdir(mp_dir)
                    if '.score' in a]:
        if 'mpframework' in sc_file:
            df = Rf.score_file2df('%s/%s' % (mp_dir, sc_file))
        else:
            df = Rf.score_file2df('%s/%s' % (scores_dir, sc_file))
        # Mutant name is the third '_'-separated token of the file name.
        name = sc_file.split('_')[2]
        if '16Mar' in sc_file:
            # 16Mar runs used a shifted residue numbering (offset of 72).
            name = '%s%i%s' % (name[0], int(name[1:-1])+72, name[-1])
        min_ddg = Rf.get_term_by_threshold(df, 'score', 5, 'a_ddg', 'mean')
        # Fixed: DataFrame.set_value was removed in pandas 1.0 (and expected
        # an index label, not a boolean mask); use .loc assignment instead.
        if 'mpframework' in sc_file:
            results['mp'][name] = min_ddg
            main_df.loc[main_df['name'] == name, 'mp'] = min_ddg - wt_mp_ddg
        else:
            results['rs'][name] = min_ddg
            main_df.loc[main_df['name'] == name, 'rs'] = min_ddg - wt_beta_ddg

    print(main_df)
    args['logger'].log(main_df)

    if args['all4']:
        # 2x2 grid: {ResSolv, MPFrameWork} x {dsTbL, Doung 2006}.
        fig = plt.figure(figsize=(10, 10), facecolor='w')
        i = 1
        for scfxn in ['rs', 'mp']:
            for exp in ['dstbl', 'Doung']:
                ax = plt.subplot(2, 2, i)
                model = linear_model.LinearRegression()
                model.fit(main_df[scfxn].to_frame(), main_df[exp].to_frame())
                line_x = np.linspace(main_df[scfxn].min(), main_df[scfxn].max())
                line_y = model.predict(line_x[:, np.newaxis])
                r2 = r2_score(main_df[exp].values,
                              model.predict(main_df[scfxn].to_frame()))
                plt.scatter(main_df[scfxn], main_df[exp])
                plt.plot(line_x, line_y)
                scfxn_name = 'ResSolv' if scfxn == 'rs' else 'MPFrameWork'
                exp_name = 'Doung 2006' if exp == 'Doung' else r'dsT$\beta$L'
                plt.title('%s Vs. %s' % (scfxn_name, exp_name))
                plt.text(0.8, 0.1, r'$R^2=%.2f$' % r2, fontsize=15,
                         horizontalalignment='center',
                         verticalalignment='center', transform=ax.transAxes)
                plt.axhline(0, color='k')
                plt.axvline(0, color='k')
                if i == 3:
                    # label only the bottom-left panel to avoid clutter
                    plt.xlabel('Rosetta ∆∆G', fontsize=18)
                    plt.ylabel('Experimental ∆∆G', fontsize=18)
                i += 1
        plt.show()

    else:
        # Two panels: ResSolv vs dsTbL and MPFrameWork vs dsTbL.
        fig = plt.figure(facecolor='w')
        ax1 = plt.subplot(1, 2, 1)
        model = linear_model.LinearRegression()
        rs_df = main_df[['name', 'dstbl', 'rs']].dropna(how='any')

        model.fit(rs_df['rs'].to_frame(), rs_df['dstbl'].to_frame())
        line_x = np.linspace(rs_df['rs'].min(), rs_df['rs'].max())
        line_y = model.predict(line_x[:, np.newaxis])
        r2 = r2_score(rs_df['dstbl'].values,
                      model.predict(rs_df['rs'].to_frame()))
        plt.scatter(rs_df['rs'], rs_df['dstbl'])
        plt.plot(line_x, line_y)
        plt.title('%s Vs. %s' % ('ResSolv', r'dsT$\beta$L'))
        plt.text(0.8, 0.1, r'$R^2=%.2f$' % r2, fontsize=15,
                 horizontalalignment='center', verticalalignment='center',
                 transform=ax1.transAxes)
        plt.axhline(0, color='k')
        plt.axvline(0, color='k')
        plt.xlabel('Rosetta ∆∆G', fontsize=18)
        plt.ylabel(r'dsT$\beta$L experimental results', fontsize=18)
        for x, y, n in zip(rs_df['rs'], rs_df['dstbl'], rs_df['name']):
            ax1.annotate(n, (x, y))

        ax2 = plt.subplot(1, 2, 2)
        model = linear_model.LinearRegression()
        mp_df = main_df[['name', 'dstbl', 'mp']].dropna(how='any')
        model.fit(mp_df['mp'].to_frame(), mp_df['dstbl'].to_frame())
        line_x = np.linspace(mp_df['mp'].min(), mp_df['mp'].max())
        line_y = model.predict(line_x[:, np.newaxis])
        r2 = r2_score(mp_df['dstbl'].values,
                      model.predict(mp_df['mp'].to_frame()))
        plt.scatter(mp_df['mp'], mp_df['dstbl'])
        plt.plot(line_x, line_y)
        plt.title('%s Vs. %s' % ('MPFrameWork', r'dsT$\beta$L'))
        plt.text(0.8, 0.1, r'$R^2=%.2f$' % r2, fontsize=15,
                 horizontalalignment='center', verticalalignment='center',
                 transform=ax2.transAxes)
        plt.axhline(0, color='k')
        plt.axvline(0, color='k')
        # Fixed: save BEFORE show() -- show() tears the figure down, so the
        # previous order wrote an empty PDF.
        plt.savefig('%s/dsTbL_alone.pdf' % scores_dir)
        plt.show()
# ---- Example #23 (web-scrape separator; stray "示例#23" / "0" tokens commented out to keep the module importable) ----
def main():
    """CLI entry point: report models passing a score/filter threshold.

    Selection modes (args['mode']):
      '%'            -- models under/over the args['percent'] percentile
                        of the chosen filter term
      'num'          -- the args['num'] best-scoring models
      'best_of_best' -- percentile pre-filter, then Rf.get_best_of_best
      'thresholds'   -- fixed per-term thresholds, then a percentile cut
    The passing model names go to the log file, or to args['result']
    when that is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-sc', type=str, help='score file')
    parser.add_argument('-percent', type=float, default=5, help='percent (1-100) best scoring to get')
    parser.add_argument('-filter', type=str, default='score', help='filter or score term to use')
    parser.add_argument('-num', default=10, type=int, help='use if you want a number of results, not better than percentile')
    parser.add_argument('-mode', default='%')
    parser.add_argument('-over_under', type=str, default='under', help='under/over score should be over/under threshold')
    parser.add_argument('-result', type=str, default=None, help='should the names be written to a file separate from the log file')
    parser.add_argument('-terms', nargs='+', default=['score', 'a_shape', 'a_pack', 'a_ddg', 'res_solv'])
    parser.add_argument('-thresholds', nargs='+', type=float)
    parser.add_argument('-percentile', default=10, type=int)
    args = vars(parser.parse_args())

    logger = Logger('top_%.1f_%s.log' % (args['percent'], args['filter']))

    # read in the score file, determine the threshold for the percentile
    sc_df = Rf.score_file2df(args['sc'])
    score = sc_df[args['filter']]

    # The modes are mutually exclusive, so chain them with elif; unknown
    # values previously fell through to a NameError on pass_df below.
    if args['mode'] == '%':
        threshold = np.percentile(score, args['percent'])
        logger.log('found a threshold for %f for filter %s to be %.2f' % (args['percent'], args['filter'], threshold))
        # keep lines that pass the threshold, either over or under it
        if args['over_under'] == 'over':
            pass_df = sc_df[sc_df[args['filter']] >= threshold]
        elif args['over_under'] == 'under':
            pass_df = sc_df[sc_df[args['filter']] <= threshold]
        else:
            raise ValueError("over_under must be 'over' or 'under', got %r" % args['over_under'])

    elif args['mode'] == 'num':
        sc_df.sort_values(args['filter'], inplace=True)
        pass_df = sc_df.head(args['num'])

    elif args['mode'] == 'best_of_best':
        threshold = np.percentile(score, args['percent'])
        sc_df = sc_df[sc_df[args['filter']] <= threshold]
        pass_df = Rf.get_best_of_best(sc_df, args['terms'], args['percentile'])

    elif args['mode'] == 'thresholds':
        # The percentile is taken over the ORIGINAL (pre-filter) scores;
        # hoisted out of the loop where it was recomputed every iteration.
        threshold = np.percentile(score, args['percent'])
        for term, thrs in zip(args['terms'], args['thresholds']):
            if term in ['a_sasa', 'a_pack', 'a_shape', 'a_tms_span_fa',
                        'a_tms_span', 'a_span_topo']:
                # "bigger is better" terms
                sc_df = sc_df[sc_df[term] > thrs]
            elif term in ['a_mars', 'a_ddg', 'score', 'total_score',
                          'a_res_solv', 'a_span_ins']:
                # "smaller is better" terms
                sc_df = sc_df[sc_df[term] < thrs]
        pass_df = sc_df[sc_df[args['filter']] < threshold]

    else:
        raise ValueError('unknown mode %r' % args['mode'])

    # output the names (description) of models that pass the threshold,
    # either to the logger file, or to a separate file
    if args['result'] is None:
        logger.create_header('models passing the threshold:')
        for idx, row in pass_df.iterrows():
            logger.log('%s %f' % (row['description'], row['score']), skip_stamp=True)
    else:
        with open(args['result'], 'w+') as fout:
            for name in pass_df['description']:
                fout.write(name + '\n')
    logger.close()
def CaCa_d_by_z(args: dict) -> None:
    """Compare CA-CA fa_elec energies to a model and a Rosetta spline log.

    Model descriptions are parsed as '<d>_<z>_...' -- assumes d is the
    first and z the second '_'-separated token (TODO confirm against the
    job that generated the score file).  For each (d, z) grid point it
    collects: the score file's fa_elec divided by 4 (presumably a
    per-interaction normalization -- verify), the analytic
    rosetta_dz_model value, and the energy parsed from the Rosetta log.
    Depending on args['plot_type'] it renders per-z 1D curves ('1ds') or
    two interpolated contourf panels ('contourf').
    """
    df = rf.score_file2df(args["sc"])
    desc = df["description"].str.split("_")

    # Parsed log: tuple of parallel (z, d, e) sequences.
    spline_log = parse_rosetta_log(args)
    spline_log_df = pd.DataFrame({"z": spline_log[0], "d": spline_log[1], "e": spline_log[2]})

    df["d"] = desc.str.get(0).astype(np.float64)
    df["z"] = desc.str.get(1).astype(np.float64)

    ds_sorted = np.round(sorted(list(set(df["d"].values))), 2)
    zs_sorted = np.round(sorted(list(set(df["z"].values))), 2)
    D, Z, E = [], [], []
    model_E = []
    spline_log_E = []
    for d in ds_sorted:
        # only distances below 10 A are examined
        if d >= 10:
            continue
        for z in zs_sorted:
            # if z > 5: continue
            D.append(d)
            Z.append(z)
            # +/-0.01 window to match d despite the rounding above
            E.append(df[(df["d"] >= d - 0.01) & (df["d"] <= d + 0.01) & (df["z"] == z)]["fa_elec"].values[0] / 4)
            model_E.append(rosetta_dz_model(d, z))
            spline_log_E.append(
                spline_log_df[
                    (spline_log_df["d"] >= d - 0.01) & (spline_log_df["d"] <= d + 0.01) & (spline_log_df["z"] == z)
                ]["e"].values[0]
            )

    args["logger"].log("finished preparing lists")
    master_df = pd.DataFrame({"d": D, "z": Z, "e_caca": E, "model_e": model_E, "spline_log_e": spline_log_E})

    if args["plot_type"] == "1ds":
        # One subplot per (every second) z value: energy curves vs d.
        args["logger"].log("making 1d plots")
        fig = plt.figure()
        plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.15, hspace=0.45)
        i = 0
        for i, z in enumerate(zs_sorted[::2]):
            plt.subplot(5, 4, 1 + i)
            d_df = master_df[master_df["z"] == z]
            plt.plot(d_df["d"], d_df["e_caca"].values, label="e_caca", c="r")
            plt.plot(d_df["d"], d_df["model_e"].values, label="model_e", c="b")
            plt.plot(d_df["d"], d_df["spline_log_e"].values, label="spline_log_e", c="g")
            plt.title("z=%.2f" % z)
            plt.xlabel("d")
            plt.ylabel("e")
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.show()

    if args["plot_type"] == "contourf":
        # Interpolate the scattered (d, z) samples onto a regular mesh.
        D_, Z_ = np.meshgrid(D, Z)

        levels = np.linspace(np.min(E), np.max(E), 1000)

        E_ = sc.interpolate.griddata((D, Z), E, (D_, Z_), method="cubic")
        model_E_ = sc.interpolate.griddata((D, Z), model_E, (D_, Z_), method="cubic")
        spline_log_E_ = sc.interpolate.griddata((D, Z), spline_log_E, (D_, Z_), method="cubic")

        fig = plt.figure()
        plt.subplot(1, 2, 1)
        cs = plt.contourf(D_, Z_, E_, levels=levels, cmap=cm.coolwarm)
        plt.xlabel("d (A)")
        plt.ylabel("z (A)")
        fig.colorbar(cs, format="%.2f")

        plt.subplot(1, 2, 2)
        cs = plt.contourf(D_, Z_, model_E_, levels=levels, cmap=cm.coolwarm)
        plt.xlabel("d (A)")
        plt.ylabel("z (A)")
        fig.colorbar(cs, format="%.2f")
        plt.show()
def quick_rmsd_total(args):
    """Scatter model score against RMSD for models passing a filter cascade.

    NOTE(review): a second definition of quick_rmsd_total exists in this
    module (it shadows the earlier one at import time).  Visible
    differences from the other copy: no a_sasa/a_res_solv filters, a
    looser a_ddg cut (-6), x-limit 30 instead of 15, and no Z-score
    annotation.  Consider renaming or removing one of the two.

    args keys read here: 'sc', 'names', 'pc', 'span_threshold', 'best',
    'percent', 'terms', 'threshold', 'show', 'logger'.
    """

    # The y axis is always the total Rosetta score.
    y_axis_term = 'score'

    sc_df = Rf.score_file2df(args['sc'], args['names'])
    args['logger'].log('found %i structs in sc_df' % len(sc_df))
    pc_df = get_rmsds_from_table(args['pc'])
    args['logger'].log('found %i structs in pc' % len(pc_df))
    # Inner merge on description: keep models present in both tables.
    a = sc_df.merge(pc_df, on='description')
    sc_df = a.copy()

    if 'a_hha' in sc_df.columns:
        # boolean helper column: positive helix-helix angle
        sc_df['angle'] = sc_df['a_hha'] > 0

    args['logger'].log('left with %i in merged df' % len(sc_df))

    args['logger'].log('examining %s with span_topo threshold %f' % (args['sc'], args['span_threshold']))
    fig, ax = plt.subplots()

    if args['best']:
        # sc_df = sc_df[ sc_df['a_span_topo'] >= 0.95 ]
        sc_df_pass = Rf.get_best_of_best(sc_df, args['terms'], args['threshold'])
        sc_df_fail = sc_df[ ~sc_df['description'].isin( sc_df_pass['description'] ) ]
        args['logger'].log('%i models returned from BEST' % len(sc_df_pass))
    else:
        # Fixed cascade of term filters, logging the survivor count at
        # each step.
        args['logger'].log('total of %i models in score' % len(sc_df))
        sc_df = sc_df[sc_df['a_tms_span_fa'] > 0.5]
        args['logger'].log('%i models pass tms_span' % len(sc_df))
        threshold = np.percentile(sc_df[y_axis_term], args['percent'])
        sc_df = sc_df[ sc_df[y_axis_term] < threshold ]
        args['logger'].log('for percent %.2f found threshold to be %.2f and %i strucutres pass it' % (args['percent'], threshold, len(sc_df)))
        sc_df = sc_df[sc_df['a_shape'] >= 0.6]
        # sc_df = sc_df[sc_df['a_sasa'] > 900]
        sc_df = sc_df[sc_df['a_ddg'] < -6]
        args['logger'].log('%i passed ddg' % len(sc_df))
        # sc_df = sc_df[sc_df['a_pack'] > 0.6]
        sc_df = sc_df[sc_df['a_unsat'] < 1]
        args['logger'].log('%i passed unsat' % len(sc_df))
        sc_df['pass'] = sc_df['a_span_topo'] > args['span_threshold']

        sc_df_pass = sc_df[sc_df['a_span_topo'] > args['span_threshold']]
        args['logger'].log('%i models passed span_topo threshold' % len(sc_df_pass))
        sc_df_fail = sc_df[sc_df['a_span_topo'] <= args['span_threshold']]
        args['logger'].log('%i models failed span_topo threshold' % len(sc_df_fail))

    # ax.scatter(sc_df_fail['rmsd_calc'].values, sc_df_fail['score'].values, color='r', marker='.')

    # Color by helix-helix angle when available, else by span topology.
    if 'a_hha' in sc_df.columns:
        ax.scatter(sc_df_pass['pc_rmsd'].values, sc_df_pass[y_axis_term].values, marker='o',
                c=sc_df_pass['a_hha'].values, picker=True, cmap=plt.cm.coolwarm)
    else:
        ax.scatter(sc_df_pass['pc_rmsd'].values, sc_df_pass[y_axis_term].values, marker='o',
                c=sc_df_pass['a_span_topo'].values, picker=True, cmap=plt.cm.coolwarm)

    # min_energy = np.nanmin(list(sc_df_pass['score'].values)+list(sc_df_fail['score'].values))
    # Pad the y limits by 1 unit around the passing models' score range.
    min_energy = np.nanmin(list(sc_df_pass[y_axis_term].values))
    max_energy = np.nanmax(list(sc_df_pass[y_axis_term].values))
    plt.ylim([min_energy - 1, max_energy + 1])
    plt.xlim([0, 30])
    plt.title(args['sc']+'_pass')

    # if 'a_hha' in sc_df.columns:
        # ax.scatter(sc_df_fail['pc_rmsd'].values, sc_df_fail[y_axis_term].values, marker='x',
                # c=sc_df_fail['a_hha'].values, picker=True, cmap=plt.cm.coolwarm, s=5, alpha=90)#, markersize=200)
    # else:
        # ax.scatter(sc_df_fail['pc_rmsd'].values, sc_df_fail[y_axis_term].values, marker='x',
                # c=sc_df_fail['a_span_topo'].values, picker=True, cmap=plt.cm.coolwarm, s=5, alpha=90)#, markersize=200)

    # af = PrintLabel(sc_df_pass, 'rmsd_calc', 'score', ['description', 'pass'])
    # fig.canvas.mpl_connect('button_press_event', af)
    # Interactive labels: clicking a point logs these columns for it.
    point_label_cols = list(set(args['terms'] + ['description', 'a_sasa', 'a_res_solv', 'a_pack', 'a_span_topo', 'a_ddg', 'fa_elec']))
    pl = PointLabel(sc_df_pass, ax, fig, 'pc_rmsd', y_axis_term, point_label_cols,
                    args['logger']) # a_shape ???
    fig.canvas.mpl_connect('pick_event', pl.onpick)
    # print('for pass')
    # print_best_scores(sc_df_pass, 'score', percentile=0.05)
    # print('for. fail')
    # print_best_scores(sc_df_fail, 'score', percentile=0.05)
    plt.xlabel('RMSD')
    plt.ylabel(y_axis_term)
    if args['show'] == 'show':
        plt.show()
    else:
        plt.savefig('%s.png' % args['sc'].split('.score')[0])
def process_folder(args):
    """Post-process a finished cluster job folder.

    Combines the per-job .sc files into one dated score file, aggregates
    error files, deletes out.*/job.* files when no non-trivial errors
    remain, then keeps only the (up to) 10 best structures by ddg --
    preferring models that pass the run filters -- and removes every
    other pdb.gz.

    Returns:
        'not finished' or 'no scores' as early-exit markers, otherwise
        None.  May call sys.exit() in the remove_pdbs branch.
    """
    import time
    # NOTE(review): pwd is the directory the script was launched from,
    # and the os.listdir calls below scan it rather than args['folder']
    # despite the chdir -- confirm whether that is intentional.
    pwd = os.getcwd() + '/'
    os.chdir(args['folder'])

    name = pwd.split('/')[-2]
    # Renamed from `time`, which shadowed the just-imported module.
    stamp = time.strftime("%d.%0-m")

    if not args['force_process']:
        if os.path.isfile('%sall_%s_%s.err' % (pwd, name, stamp)):
            print('found %sall_%s_%s.err, STOPPING' % (pwd, name, stamp))
            if args['remove_pdbs']:
                remove_pdbs_only(pwd, name, stamp)
                sys.exit()
            return 'not finished'
        if not is_folder_finished(pwd):
            return 'not finished'
    sc_files = [a for a in os.listdir(pwd) if a[-3:] == '.sc']
    pdb_files = [a for a in os.listdir(pwd) if a[-7:] == '.pdb.gz']
    err_files = [a for a in os.listdir(pwd) if a[:4] == 'err.']
    job_files = [a for a in os.listdir(pwd) if a[:4] == 'job.']
    print(
        'found a total of %i job files, %i err files, %i pdbs and %i scores' %
        (len(job_files), len(err_files), len(pdb_files), len(sc_files)))
    if len(sc_files) == 0:
        return 'no scores'
    combine_scores('%sall_%s_%s.score' % (pwd, name, stamp), sc_files)
    non_triv_errs = process_errors('%sall_%s_%s.err' % (pwd, name, stamp),
                                   err_files)

    if non_triv_errs == 0:
        print('removing out.* and job.*')
        # Plain loops (were list comprehensions used only for side effects).
        for out in os.listdir(pwd):
            if out[:4] == 'out.':
                os.remove(out)
        for job in job_files:
            os.remove(job)
        try:
            os.remove('./command')
        except OSError:
            # narrowed from a bare except: a missing file is fine to ignore
            pass

    run_filters = DCRP.generate_run_filters(
        args={
            'ddg': 24.0,
            'sasa': 1400,
            'shape': 0.6,
            'packstat': 0.6,
            'buried_2': 3,
            'hbonds': -10.
        })
    score_dict = RF.score2dict('%sall_%s_%s.score' % (pwd, name, stamp))
    passed, failed = DCRP.all_who_pass_run_filters({}, score_dict, run_filters)
    if len(passed) != 0:
        print('there are %i passed scores, so choosing from there' %
              len(list(passed.keys())))
        best_structs = DCRP.best_n_structures(
            {
                'filter': 'ddg',
                'n': min([10, len(list(passed.keys()))])
            }, passed)
    else:
        print('there are no passed, so choosing from the failed')
        best_structs = DCRP.best_n_structures(
            {
                'filter': 'ddg',
                'n': min([10, len(list(failed.keys()))])
            }, failed)
    best_names = [a['description'] for a in best_structs]
    print('the best:', best_names)
    print('removing all other pdbs')
    for pdb in pdb_files:
        if pdb[:-7] not in best_names:
            os.remove(pdb)
    os.chdir(pwd)
# ---- Example #27 (web-scrape separator; stray "示例#27" / "0" tokens commented out to keep the module importable) ----
def slide_ddg(args):
    """Interactive scatter plot with per-term filter sliders.

    Plots args['x'] vs. args['y'] from the (optionally RMSD-merged)
    score file and wires up: one Slider per term in args['terms']
    (filtering through the module-level update callback), Reset / Print /
    log-table buttons, a radio group to re-color by term, and pick-event
    point labeling.  Communicates with the callbacks defined elsewhere in
    this module via the declared globals.
    """
    global new_df, radio, color_by, picked
    global scat, ax, sliders, sc_df, fig, cm, cbar
    # NOTE(review): this rebinds the global name `cm`; other functions in
    # this file use `cm` as the matplotlib colormap module (cm.coolwarm) --
    # confirm there is no call-ordering hazard before refactoring.
    sc_df = Rf.score_file2df(args['sc'], args['names'])
    args['logger'].log('score file has %i entries' % len(sc_df))
    if args['pc'] is not None:
        pc_df = get_rmsds_from_table(args['pc'])
        args['logger'].log('pc file had %i entries' % len(pc_df))
        a = sc_df.merge(pc_df, on='description')
        args['logger'].log('combined there are %i entries' % len(a))
        sc_df = a.copy()

    if args['percent'] != 100:
        # keep only the best-scoring args['percent'] of models
        threshold = np.percentile(sc_df[args['y']], args['percent'])
        sc_df = sc_df[sc_df[args['y']] < threshold]

    color_by = args['y']
    picked = False

    new_df = sc_df.copy()
    fig, ax = plt.subplots()
    plt.subplots_adjust(left=0.25, bottom=0.25)

    cm = plt.cm.get_cmap('RdYlBu')

    scat = ax.scatter(sc_df[args['x']].values,
                      sc_df[args['y']].values,
                      s=40,
                      cmap=cm,
                      c=sc_df[color_by],
                      picker=True)
    cbar = plt.colorbar(scat)
    # One horizontal slider per filter term, stacked at the bottom.
    sliders = {}
    for i, term in enumerate(args['terms']):
        slider_ax = plt.axes([0.25, 0.01 + i * 0.035, 0.65, 0.03])
        sliders[term] = Slider(slider_ax, term, np.min(sc_df[term].values),
                               np.max(sc_df[term].values), 0)
        sliders[term].on_changed(update)

    # Pad both axes by 1 unit around the data range.
    ax.set_xlim(
        np.min(new_df[args['x']].values) - 1,
        np.max(new_df[args['x']].values) + 1)
    ax.set_ylim(
        np.min(new_df[args['y']].values) - 1,
        np.max(new_df[args['y']].values) + 1)

    ax.set_xlabel(args['x'])
    ax.set_ylabel(args['y'])

    resetax = plt.axes([0.025, 0.7, 0.15, 0.15])
    button = Button(resetax,
                    'Reset',
                    color='lightgoldenrodyellow',
                    hovercolor='0.975')
    button.on_clicked(reset)

    printax = plt.axes([0.025, 0.3, 0.15, 0.15])
    printbutton = Button(printax, 'Print', color='green', hovercolor='red')
    printbutton.on_clicked(print_table)

    logax = plt.axes([0.025, 0.1, 0.15, 0.15])
    logbutton = Button(logax, 'log table', color='blue', hovercolor='red')
    logbutton.on_clicked(log_table)

    # Fixed: the `axisbg` keyword was removed in matplotlib 2.2; the
    # replacement is `facecolor`.
    rax = plt.axes([0.025, 0.5, 0.15, 0.15], facecolor='white')
    radio = RadioButtons(rax, args['terms'], active=0)
    radio.on_clicked(colorfunc)

    # Clicking a point logs these columns for it.
    pl = PointLabel(new_df, ax, fig, args['x'], args['y'], [
        'description', 'a_sasa', 'a_res_solv', 'a_pack', 'a_span_topo',
        'a_ddg', 'fa_elec'
    ], args['logger'])
    fig.canvas.mpl_connect('pick_event', pl.onpick)

    plt.show()