def spread_filters(args: dict, run_filters: RunFilters, score_dict: dict):
    """
    Plot, side by side, the values each filter has seen (flt.all_seen), one subplot per filter.

    'buried_2' is drawn as a histogram; every other filter type is drawn as a box plot whose y-axis is
    clipped to a fixed range when one is known for that filter type. Finishes with plt.show().

    :param args: run arguments (not used here)
    :param run_filters: filters to plot; one subplot per entry of run_filters.filters
    :param score_dict: score dictionary (not used here)
    """
    plt.figure(1)
    lims = {'sasa': [0, 2000], 'ddg': [-40, 5], 'packstat': [0.0, 1.0], 'shape': [0.0, 1.0], 'score': [-500, 10],
            'rmsd': [0.0, 30.0]}
    # Count the same collection that is iterated, so every subplot index stays within the grid.
    flts = list(run_filters.filters.values())
    for i, flt in enumerate(flts, start=1):
        plt.subplot(1, len(flts), i)
        plt.title(flt.filter_type)
        if flt.filter_type == 'buried_2':
            plt.hist(flt.all_seen)
        else:
            plt.boxplot(flt.all_seen)
            # Filter types without a known range (e.g. 'hbonds') keep matplotlib's automatic limits
            # instead of raising KeyError.
            ylim = lims.get(flt.filter_type)
            if ylim is not None:
                plt.ylim(ylim)
    plt.show()
def all_who_pass_run_filters(args: dict, score_dict: dict, runfilters: RunFilters) -> 'tuple[dict, dict]':
    """
    Split score_dict into the entries that pass all run filters and the entries that fail.

    :param args: run arguments (not used here; kept for interface compatibility)
    :param score_dict: score dictionary, mapping a structure name to its scores
    :param runfilters: run filters; runfilters.test_all(scores) is expected to return (passed, message)
    :return: (passed, failed) - two dicts with the same name -> scores mapping as score_dict
    """
    passed, failed = {}, {}
    msg = False
    for name, scores in score_dict.items():
        res, msg = runfilters.test_all(scores)
        if res:
            passed[name] = scores
        else:
            failed[name] = scores
    # NOTE(review): msg is overwritten on every iteration, so only the message returned for the LAST
    # entry is printed - confirm that is intended.
    if msg:
        print(msg)
    return passed, failed
def _thresholds_from_low_rmsd(score_dict: dict, rmsd_threshold: float):
    """Return (filter_thresholds, low_rmsd_scores) computed from the entries with rmsd <= rmsd_threshold."""
    # Starting bounds; each is moved toward the observed values with min/max below.
    filter_thresholds = dict(a_score=-100000., a_sasa=100000., a_shape=1.0, total_score=1000000., a_ddg=-100.,
                             a_packstat=1.0, a_buried_2=0)
    # (threshold key, score key, reducer): max for 'under' filters, min for 'over' filters, so that every
    # low-RMSD structure would pass. NOTE(review): total_score uses min and is not turned into a filter.
    updates = (('a_score', 'score', max), ('a_sasa', 'sasa', min), ('a_shape', 'shape', min),
               ('total_score', 'total_score', min), ('a_ddg', 'ddg', max), ('a_packstat', 'packstat', min),
               ('a_buried_2', 'buried_2', max))
    low_rmsd = []
    for sc in score_dict.values():
        if sc['rmsd'] <= rmsd_threshold:
            for thr_key, score_key, pick in updates:
                filter_thresholds[thr_key] = pick(filter_thresholds[thr_key], sc[score_key])
            low_rmsd.append(sc)
    return filter_thresholds, low_rmsd


def _run_filters_from_thresholds(filter_thresholds: dict):
    """Build a RunFilters whose thresholds come from filter_thresholds (plus a permissive RMSD filter)."""
    run_filters = RunFilters()
    run_filters.append_filter(Filter(name='a_ddg', typ='ddg', threshold=filter_thresholds['a_ddg'],
                                     limits=[-10000, 10000], under_over='under', g_name=r'$\Delta$$\Delta$G'))
    run_filters.append_filter(Filter(name='a_score', typ='score', threshold=filter_thresholds['a_score'],
                                     limits=[-10000, 10000], under_over='under', g_name='Score'))
    run_filters.append_filter(Filter(name='a_sasa', typ='sasa', threshold=filter_thresholds['a_sasa'],
                                     limits=[0, 100000], under_over='over', g_name='SASA'))
    run_filters.append_filter(Filter(name='a_shape', typ='shape', threshold=filter_thresholds['a_shape'],
                                     limits=[0.0, 1.0], under_over='over', g_name='Shape Complementarity'))
    run_filters.append_filter(Filter(name='a_packstat', typ='packstat', threshold=filter_thresholds['a_packstat'],
                                     limits=[0.0, 1.0], under_over='over', g_name='PackStat'))
    run_filters.append_filter(Filter(name='a_buried_2', typ='buried_2', threshold=filter_thresholds['a_buried_2'],
                                     limits=[0, 100], under_over='under', g_name='UnsatisfiedHBonds'))
    run_filters.append_filter(Filter(name='a_rms', typ='rmsd', threshold=1000, limits=[0, 1000],
                                     under_over='under', g_name='RMSD'))
    return run_filters


def find_thresholds_by_rmsd(args):
    """
    Derive filter thresholds from the structures with low RMSD, report them and plot the outcome.

    For every structure in args['score_file'] with rmsd <= args['rmsd_threshold'], each filter threshold is
    set to the most permissive value seen (max for 'under' filters, min for 'over' filters), so that all of
    those structures would pass.

    Side effects: prints the old and new thresholds, stores the new thresholds in args['dimer_data'],
    sets args['x'], args['y'] = 'rmsd', 'ddg', draws plots and calls plt.show().

    :param args: run args; reads 'score_file' and 'rmsd_threshold'
    :return: None
    """
    score_dict = score2dict(args['score_file'])
    filter_thresholds, passed_rmsd = _thresholds_from_low_rmsd(score_dict, args['rmsd_threshold'])
    print('found %i scores with rmsd <= %f' % (len(passed_rmsd), args['rmsd_threshold']))
    if not passed_rmsd:
        # Without any qualifying structure the thresholds are just the initial placeholder bounds.
        print('WARNING: no structure has rmsd <= %f; the thresholds below are placeholders'
              % args['rmsd_threshold'])
    print('the old thresholds were:', generate_run_filters().report())
    print('defined these new filters:\n%s' % '\n'.join(['%s %f' % (k, v) for k, v in filter_thresholds.items()]))
    args['dimer_data'] = filter_thresholds
    run_filters_updated = _run_filters_from_thresholds(filter_thresholds)
    passed, failed = all_who_pass_run_filters(args, score_dict, run_filters_updated)
    multiple_plots(args, run_filters_updated, passed, failed, score_dict)
    args['x'], args['y'] = 'rmsd', 'ddg'
    this_vs_that(args, run_filters_updated, passed, failed, score_dict)
    plt.show()
def generate_run_filters(args=None, filter_thresholds=None) -> RunFilters:
    """
    Build the standard RunFilters: ddg, score, sasa, shape, packstat, buried_2, rmsd and optionally hbonds.

    the original thresholds were: ddg -22.5, score -390, sasa 1460, shape 0.623, packstat 0.641, buried_2 3

    :param args: dict of thresholds keyed 'ddg', 'sasa', 'shape', 'packstat', 'buried_2' and optionally
        'hbonds'. Missing keys (or args=None) fall back to ddg -18.0, sasa 1200, shape 0.6, packstat 0.6,
        buried_2 3. The H-bond filter is only added when 'hbonds' is provided.
    :param filter_thresholds: when not None, an EMPTY RunFilters is returned (behaviour kept unchanged).
        NOTE(review): presumably meant to build the filters from explicit thresholds - confirm.
    :rtype : RunFilters
    """
    run_filters = RunFilters()
    if filter_thresholds is not None:
        return run_filters
    # Defaults now include buried_2 (the original threshold above), so a bare generate_run_filters() call
    # no longer raises KeyError; caller-supplied values override them.
    settings = {'ddg': -18.0, 'sasa': 1200, 'shape': 0.6, 'packstat': 0.6, 'buried_2': 3}
    if args is not None:
        settings.update(args)
    # ddg and hbonds thresholds are forced negative, so callers may pass either sign.
    run_filters.append_filter(Filter(name='a_ddg', typ='ddg', threshold=-abs(settings['ddg']), limits=[-10000, 10000],
                                     under_over='under', g_name=r'$\Delta$$\Delta$G'))
    run_filters.append_filter(Filter(name='a_score', typ='score', threshold=-390.0, limits=[-10000, 10000],
                                     under_over='under', g_name='Score'))
    run_filters.append_filter(Filter(name='a_sasa', typ='sasa', threshold=settings['sasa'], limits=[0, 100000],
                                     under_over='over', g_name='SASA'))
    run_filters.append_filter(Filter(name='a_shape', typ='shape', threshold=settings['shape'], limits=[0.0, 1.0],
                                     under_over='over', g_name='Shape Complementarity'))
    run_filters.append_filter(Filter(name='a_packstat', typ='packstat', threshold=settings['packstat'],
                                     limits=[0.0, 1.0], under_over='over', g_name='PackStat'))
    run_filters.append_filter(Filter(name='a_buried_2', typ='buried_2', threshold=settings['buried_2'],
                                     limits=[0, 100], under_over='under', g_name='UnsatisfiedHBonds'))
    run_filters.append_filter(Filter(name='a_rms', typ='rmsd', threshold=1000, limits=[0, 1000],
                                     under_over='under', g_name='RMSD'))
    if 'hbonds' in settings:
        run_filters.append_filter(Filter(name='a_hbonds', typ='hbonds', threshold=-abs(settings['hbonds']),
                                         limits=[-1000, 1000], under_over='under', g_name='H. bonds'))
    # TODO: consider adding a 'coh_packstat' packstat filter (threshold not decided yet).
    return run_filters