# Descriptor sets with the best performance; their results get plotted.
best_descs = ['haddad_desc', 'all', 'vib_100']

# Human-readable labels for each parameter-search variant.
nice_names = dict(none_opt='default parameters (q2)',
                  k_opt='k_best optimized (q2)',
                  reg_opt='C optimized (q2)',
                  both_opt='both optimized (q2)')

# Shared axis limit / tick setup for the scatter plots below.
line_max = 0.9
ticks = [0, 0.2, 0.4, 0.6, 0.8, 0.9]
ticklabels = ['0', '.2', '.4', '.6', '.8', '']

# Results for the chosen descriptor sets, in best_descs order.
vals = [out_res[desc] for desc in best_descs]

# Compare each optimized variant against the un-optimized reference.
reference_name = 'none_opt'
to_compare = ['k_opt', 'reg_opt']
titles = ['c)', 'd)']
for i, pick_type in enumerate(to_compare):

    # Pool the generalization scores of all descriptor sets into flat arrays.
    # (Names `reference`/`improved` are kept: later code reads them.)
    reference = np.array(utils.flatten([r[reference_name] for r in vals]))
    improved = np.array(utils.flatten([r[pick_type] for r in vals]))

    # Restrict to cases that reach genscore > 0 after the parameter search.
    keep = improved > 0
    reference, improved = reference[keep], improved[keep]
    # Scores below zero are all equally bad -> clamp them to 0 for plotting.
    reference[reference < 0] = 0
    improved[improved < 0] = 0

    # Scatter of optimized vs. default scores plus an identity line.
    ax = plt.subplot(gs[0, i])
    ax.plot(reference, improved, 'ko', alpha=0.6, markersize=4)
    ax.plot([0, line_max - 0.05], [0, line_max - 0.05], color='0.5')
    ax.set_ylabel(nice_names[pick_type])
    plt.axis('scaled')
    ax.set_xlim([-0.05, line_max])
fig = plt.figure()
plot_res = {desc:{'ps': [], 'corrs': []} for desc in to_compare}
for desc in to_compare:

    for glom in res[reference]:
        corr, p = stats.pearsonr(res[reference][glom]['predictions'],
                                 res[desc][glom]['predictions'])
        plot_res[desc]['corrs'].append(corr)
        plot_res[desc]['ps'].append(p)

for desc, pres in plot_res.items():
    print '{} mean r: {:.2f}'.format(desc, np.mean(pres['corrs']))


ax = fig.add_subplot(1,2,1)
ax.hist(utils.flatten([v['ps'] for v in plot_res.values()]))

# Figure comparing per-glomerulus scores of each optimized variant against
# the reference model.
fig = plt.figure(figsize=(3.35, 1.8))
marker = ['o', 's', 'd']
xticks = [0, 0.2, 0.4, 0.6, 0.8]
xticklabels = ['0', '.2', '.4', '.6', '.8']
ax = fig.add_subplot(1,2,1)
# Identity line for the score-vs-score scatter.
ax.plot([-0.05, 0.8], [-0.05, 0.8], color='0.6')
col_ref = []
col_comp = []
for i, (desc, pres) in enumerate(plot_res.items()):
    compare_scores = np.array([res[desc][g]['score'] for g in res[desc]])
    # NOTE(review): `reference` here is the numpy array left over from the
    # earlier plotting loop, not a dict key; this presumably should be
    # `reference_name` -- confirm against the original script.
    ref_scores = np.array([res[reference][g]['score'] for g in res[reference]])
    # Scores below zero are treated as equally bad -> clamp to 0.
    compare_scores[compare_scores < 0] = 0
    ref_scores[ref_scores < 0] = 0
    col_ref.extend(ref_scores)
    # NOTE(review): the source appears truncated here -- the loop body above
    # is cut off and the following lines are the tail of a different function
    # (apparently greedy_selection). Do not edit without the full file.
        chosen = sorted_tmp[0][0]
        res.append(sorted_tmp[0])
    return res, chosen


# Repeated cross-validated greedy dimension selection: 100 rounds of 5-fold
# CV; on each training fold greedily pick the subset of eva_space dimensions
# whose pairwise distances best correlate (pearson) with the 'correlation'
# distances of rm, then validate that subset on the held-out fold.
val_res = []   # pearson r on the held-out fold, one entry per CV split
chosens = []   # dimension subset picked on each CV split

for _ in range(100):
    kf = KFold(rm.shape[0], 5, indices=False, shuffle=True)
    for train, test in kf:
        # Hoisted out of `measure`: the target distances depend only on the
        # training fold, but the original recomputed them on every greedy
        # evaluation inside greedy_selection.
        target_dists = pdist(rm[train], 'correlation')
        measure = lambda x: stats.pearsonr(target_dists, pdist(x))[0]

        greedy_res, chosen = greedy_selection(eva_space[train], measure)
        # Best-scoring subset found during the search; max() with a key is
        # equivalent to sorting descending and taking the first entry.
        best_chosen = max(greedy_res, key=lambda t: t[1])[0]
        chosens.append(best_chosen)

        # Validate the chosen dimensions on the held-out fold.
        perf = stats.pearsonr(pdist(rm[test], 'correlation'),
                              pdist(eva_space[np.ix_(test, best_chosen)]))[0]
        val_res.append(perf)


# Summary figure: validation performance and which dimensions got picked.
fig = plt.figure()

ax = fig.add_subplot(211)
ax.hist(val_res, color='0.5')
ax.set_xlabel('histogram over r, mean: {}'.format(np.mean(val_res)))

ax = fig.add_subplot(212)
# One histogram bin per available dimension of eva_space.
n_dims = eva_space.shape[1]
ax.hist(utils.flatten(chosens),
        bins=range(n_dims + 1),
        color='0.5')
ax.set_xlabel('dimension selection histogram')
fig.subplots_adjust(hspace=0.2)