class TestVisualisation(unittest.TestCase):
    def setUp(self):
        self.space = ParameterSpace(PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.), PSlice(.1, 1.0, .1), PSlice(-30., 5., 5.), PSlice(120), PSlice(30), PSlice(10, 50, 10), PSlice(10), PSlice(20), PSlice(200), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.), PSlice(2.))
        self.space.load_analysis_results()
        self.p = self.space.get_nontrivial_subspace(('noise_rate_mu', 40), ('bias', -20), ('active_mf_fraction', 0.5)).item(0)        
    def test_heatmap(self):
        fig, ax = self.space.get_nontrivial_subspace(('noise_rate_mu', 40)).plot_2d_heatmap('point_mi_qe')
        self.assertEqual(len(ax.get_images()), 1)
    def test_mi_detail_precision(self):
        midp = MIDetailPlotter(point=self.p, corrections=('plugin', 'qe'), fig_title='test', label_prefix='nm50_b-20')
        fig, ax = midp.plot()
        self.assertEqual(len(ax.get_lines()), 4)
    def test_mi_detail_size(self):
        midp = MIDetailPlotter(point=self.p, corrections=('plugin', 'qe'), fig_title='test', label_prefix='nm50_b-20')
        fig, ax = midp.plot(mode='alphabet_size')
        self.assertEqual(len(ax.get_lines()), 4)
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from matplotlib import pyplot as plt
import numpy as np

from parameters import PSlice, ParameterSpace, ParameterSpacePoint
from visualisation import InteractiveHeatmap

space = ParameterSpace(PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.), PSlice(.6), PSlice(-10.), PSlice(120), PSlice(30), PSlice(10), PSlice(10), PSlice(20), PSlice(50,210,30), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.), PSlice(2.))
space.load_analysis_results()


space_1 = space
space_2 = ParameterSpace(PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.), PSlice(.6), PSlice(-10.), PSlice(120), PSlice(30), PSlice(10), PSlice(10), PSlice(20), PSlice(500,900,300), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.), PSlice(2.))
space_2.load_analysis_results()

space_3 = np.hstack([space_1.get_nontrivial_subspace(('noise_rate_mu', 10)), space_2.get_nontrivial_subspace(('noise_rate_mu', 10))])


corrections = ['plugin', 'bootstrap', 'qe', 'pt', 'nsb']
values = np.vstack([getattr(x, 'ts_decoded_mi_{0}'.format(correction), np.nan)[0:200] for correction in corrections for x in space_3.flat])
labels = ['{0}, {1}'.format(correction, getattr(x, 'n_trials') - getattr(x,'training_size')) for correction in corrections for x in space_3.flat]

fig,ax = plt.subplots()
print values.shape
plot = ax.imshow(values, interpolation='none', aspect='auto', cmap='coolwarm', origin='lower')
ax.set_xticks([20]+ax.get_xticks())
ax.set_xlim(0,values.shape[1]-1)
# ax.set_xticklabels(['10', '40', '70', '100', '130', '160', '460', '760'])
ax.set_yticks(range(values.shape[0]))
ax.set_yticklabels(labels)