class TestVisualisation(unittest.TestCase):
    """Smoke tests for the heatmap and MI-detail plotters.

    Each test draws a figure from pre-computed analysis results and checks
    how many artists (images or lines) end up on the returned axes.
    """

    def setUp(self):
        # One PSlice per ParameterSpace dimension, in positional order.
        slices = (PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.),
                  PSlice(.1, 1.0, .1), PSlice(-30., 5., 5.), PSlice(120),
                  PSlice(30), PSlice(10, 50, 10), PSlice(10), PSlice(20),
                  PSlice(200), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.),
                  PSlice(2.))
        self.space = ParameterSpace(*slices)
        self.space.load_analysis_results()
        # The first element of the subspace selected by these three
        # coordinates is the point used by the MI-detail tests.
        selected = self.space.get_nontrivial_subspace(
            ('noise_rate_mu', 40), ('bias', -20), ('active_mf_fraction', 0.5))
        self.p = selected.item(0)

    def _mi_detail_plotter(self):
        """Build the MIDetailPlotter shared by the MI-detail tests."""
        # NOTE(review): the prefix says nm50 although the point is selected at
        # noise_rate_mu=40 - confirm which label is intended.
        return MIDetailPlotter(point=self.p, corrections=('plugin', 'qe'),
                               fig_title='test', label_prefix='nm50_b-20')

    def test_heatmap(self):
        subspace = self.space.get_nontrivial_subspace(('noise_rate_mu', 40))
        _, ax = subspace.plot_2d_heatmap('point_mi_qe')
        self.assertEqual(len(ax.get_images()), 1)

    def test_mi_detail_precision(self):
        _, ax = self._mi_detail_plotter().plot()
        self.assertEqual(len(ax.get_lines()), 4)

    def test_mi_detail_size(self):
        _, ax = self._mi_detail_plotter().plot(mode='alphabet_size')
        self.assertEqual(len(ax.get_lines()), 4)
import sys
import os

# Put the parent of the *current working directory* on the import path so the
# local `parameters` package can be found. This depends on where the script is
# launched from, not on where this file lives.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

from parameters import ParameterSpace
from scipy.optimize import fmin_bfgs
import numpy as np


# Fixed seed so that the synthetic data set is reproducible.
np.random.seed(1234)

# Single free parameter 'b'. The arguments of add_parameter / set_bounds are
# presumably (initial value, name) and (lower, upper) - confirm in `parameters`.
p = ParameterSpace()
p.add_parameter(0, 'b')
p['b'].set_bounds(-10, 10000)

# Synthetic linear data: y = 2 * x + unit Gaussian noise (true slope is 2).
# x must be drawn before the noise to keep the random stream unchanged.
x = np.random.randn(1000)
y = 2 * x + np.random.randn(1000)

def squares(params, p, y, x):
    """Return the residual sum of squares of the model ``y ~ b * x``.

    The optimiser's current guess ``params`` is pushed into the parameter
    space ``p`` with ``p.update``. The slope is then read back from
    ``p['b'].value``, which may differ from the raw guess if the library
    applies bounds (unverified). Both the raw guess and the updated space are
    printed on every call.
    """
    print(params)
    p.update(params)
    print(p)
    residuals = y - p['b'].value * x
    return np.square(residuals).sum()

# Start the optimiser from the current value of 'b'.
# NOTE(review): this reads `value_` while `squares` reads `value`; confirm
# which attribute the parameters library intends for the starting point.
init = p['b'].value_

# BFGS minimisation of the residual sum of squares; with the default
# full_output=0, fmin_bfgs returns the minimising parameter vector.
fmin = fmin_bfgs(squares, init, args=(p, y, x))

p.summary()
 def setUp(self):
     """Load analysis results and pick the point used by the tests.

     ``self.space`` holds the loaded parameter space. ``self.p`` is the first
     element of its subspace at noise_rate_mu=40, bias=-20 and
     active_mf_fraction=0.5.
     """
     # One PSlice per ParameterSpace dimension, in positional order.
     slices = (PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.),
               PSlice(.1, 1.0, .1), PSlice(-30., 5., 5.), PSlice(120),
               PSlice(30), PSlice(10, 50, 10), PSlice(10), PSlice(20),
               PSlice(200), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.),
               PSlice(2.))
     self.space = ParameterSpace(*slices)
     self.space.load_analysis_results()
     selected = self.space.get_nontrivial_subspace(
         ('noise_rate_mu', 40), ('bias', -20), ('active_mf_fraction', 0.5))
     self.p = selected.item(0)
#!/usr/bin/python
from parameters import ParameterSpace, ParameterSet


if __name__=='__main__':
  # Build a ParameterSpace of run configurations by sweeping the reduction
  # mode, the process count and the grid size.
  from argparse import ArgumentParser
  parser = ArgumentParser()
  # Both positional arguments are optional (nargs='?') and fall back to the
  # defaults given here.
  # NOTE(review): `arguments.template` is not used in the lines shown here;
  # presumably it is consumed further down the script - confirm.
  parser.add_argument('template', nargs='?', default='template.sh')
  parser.add_argument('tasksPerNode', nargs='?', default=16, type=int)
  arguments = parser.parse_args()

  #Set up parameter space
  parameterSpace = ParameterSpace()

  # Growth factor for tFinal: each time gridSize is tripled below
  # (162, 486, 1458, 4374), tFinal is multiplied by f (1, f, f*f, f*f*f).
  f=5

  # Both values of reduceReductions are passed straight through to ParameterSet.
  for reduceReductions in [0, 1]:
    # 11+81 == 92 and 11+81+729 == 821; a 6562-process run is commented out.
    for processes in [1, 10, 11, 82, 83, 11+81, 730, 11+81+729 ]: #, 6562]:
      # Only 82, 83 and 730 processes use a fork-level increment of 2.
      # NOTE(review): the reason is not visible here - confirm against the
      # code that consumes forkLevelIncrement.
      forkLevelIncrement = 1
      if(processes == 82 or processes == 83 or processes == 730):
        forkLevelIncrement = 2


      # Settings shared by every configuration with this process count.
      psProcesses = ParameterSet(processes=processes, tasks_per_node=arguments.tasksPerNode, forkLevelIncrement=forkLevelIncrement, reduceReductions=reduceReductions)

      #6x6 Patches
      # derive() presumably returns a copy of the set with the given keys
      # overridden - verify in `parameters`.
      ps6x6PatchSize = psProcesses.derive(patchSize=6)
      parameterSpace.addParameterSet(ps6x6PatchSize.derive(gridSize=162, tFinal=1, useHeapCompression=1))
      parameterSpace.addParameterSet(ps6x6PatchSize.derive(gridSize=486, tFinal=f, useHeapCompression=1))
      parameterSpace.addParameterSet(ps6x6PatchSize.derive(gridSize=1458, tFinal=f*f, useHeapCompression=1))
      parameterSpace.addParameterSet(ps6x6PatchSize.derive(gridSize=4374, tFinal=f*f*f, useHeapCompression=1))
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from matplotlib import pyplot as plt
import numpy as np

from parameters import PSlice, ParameterSpace, ParameterSpacePoint
from visualisation import InteractiveHeatmap

space = ParameterSpace(PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.), PSlice(.6), PSlice(-10.), PSlice(120), PSlice(30), PSlice(10), PSlice(10), PSlice(20), PSlice(50,210,30), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.), PSlice(2.))
space.load_analysis_results()


# space_2 differs from space only in its 13th slice (PSlice(500,900,300)
# instead of PSlice(50,210,30)); the two are stitched together below.
space_1 = space
space_2 = ParameterSpace(PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.), PSlice(.6), PSlice(-10.), PSlice(120), PSlice(30), PSlice(10), PSlice(10), PSlice(20), PSlice(500,900,300), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.), PSlice(2.))
space_2.load_analysis_results()

space_3 = np.hstack([space_1.get_nontrivial_subspace(('noise_rate_mu', 10)), space_2.get_nontrivial_subspace(('noise_rate_mu', 10))])


# Number of leading samples of each decoded-MI series that are shown.
n_points = 200
corrections = ['plugin', 'bootstrap', 'qe', 'pt', 'nsb']
# A point lacking a correction's attribute gets an all-NaN row of the right
# length. A bare np.nan default cannot be sliced and raised TypeError.
missing_row = np.full(n_points, np.nan)
# One row per (correction, point) pair, with corrections in the outer loop.
# The labels are built in the same order, so they line up with the rows.
values = np.vstack([getattr(x, 'ts_decoded_mi_{0}'.format(correction), missing_row)[0:n_points] for correction in corrections for x in space_3.flat])
labels = ['{0}, {1}'.format(correction, x.n_trials - x.training_size) for correction in corrections for x in space_3.flat]

fig,ax = plt.subplots()
# Function-call form prints the same on Python 2 and is valid on Python 3.
print(values.shape)
plot = ax.imshow(values, interpolation='none', aspect='auto', cmap='coolwarm', origin='lower')
# Add one extra tick at x=20. get_xticks() returns an ndarray, so
# `[20] + ax.get_xticks()` would broadcast and shift every tick by 20;
# union1d merges the two sets of ticks (sorted, without duplicates).
ax.set_xticks(np.union1d([20], ax.get_xticks()))
ax.set_xlim(0,values.shape[1]-1)
# ax.set_xticklabels(['10', '40', '70', '100', '130', '160', '460', '760'])
ax.set_yticks(range(values.shape[0]))
ax.set_yticklabels(labels)
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from matplotlib import pyplot as plt
# Interactive mode: figures are drawn without blocking the interpreter.
plt.ion()

from parameters import PSlice, ParameterSpace
from visualisation import InteractiveHeatmap

# One PSlice per ParameterSpace dimension, in positional order.
space = ParameterSpace(PSlice(300), PSlice(6), PSlice(2.),
                       PSlice(4), PSlice(5.), PSlice(.1, 1., .1),
                       PSlice(-30.,5.,5.), PSlice(120), PSlice(30),
                       PSlice(10), PSlice(10), PSlice(20),
                       PSlice(200), PSlice(40), PSlice(0.), PSlice(0),
                       PSlice(5.), PSlice(2.))
space.load_analysis_results()


# Heatmap of 'o_synchrony' over the subspace at noise_rate_mu == 10.
noise_subspace = space.get_nontrivial_subspace(('noise_rate_mu', 10))
ihm = InteractiveHeatmap(noise_subspace, alpha=1)
ihm.plot('o_synchrony', invert_axes=False)

label_fontsize = 16
ihm.ax.set_xlabel('inhibitory current (pA)', fontsize=label_fontsize)
# Alternative x label, for when the x axis is the number of dendrites:
#ihm.ax.set_xlabel('number of granule cell dendrites', fontsize=16)
ihm.ax.set_ylabel('"on" mossy terminals fraction', fontsize=label_fontsize)
ihm.cbar.set_label('output layer synchrony', fontsize=label_fontsize)

# NOTE(review): in interactive mode plt.show() does not block, so a plain
# `python script.py` run may exit and close the window immediately;
# presumably this is meant to be run from an interactive session - confirm.
plt.show()
#! /usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter

from parameters import PSlice, ParameterSpace

def scale_nspikes_to_freq(nspikes, pos):
    """Tick formatter: turn a spike count into a rounded firing-rate label.

    The count is scaled by 10/3 and rounded to an int. The factor presumably
    converts spikes per 300 ms window to Hz - confirm. ``pos`` (the tick
    position passed by FuncFormatter) is ignored.
    """
    scaled = 10 * nspikes / 3.
    return int(round(scaled))

space = ParameterSpace(PSlice(300), PSlice(6), PSlice(2.), PSlice(4), PSlice(5.), PSlice(.1, 1., .1), PSlice(-30., 5., 5.), PSlice(120), PSlice(30), PSlice(10,80,10), PSlice(10), PSlice(20), PSlice(200), PSlice(40), PSlice(0.), PSlice(0), PSlice(5.), PSlice(2.))
space.load_analysis_results()

# Pool the flattened o_level_array of every point in the bias == 0 subspace
# into one 1-D array for the histogram.
b = np.concatenate([point.o_level_array.flatten() for point in space.get_nontrivial_subspace(('bias', 0)).flat])

fig, ax = plt.subplots()
ax.hist(b)
# Relabel the x axis (spike counts) as firing rates.
ax.xaxis.set_major_formatter(FuncFormatter(scale_nspikes_to_freq))
ax.yaxis.get_major_formatter().set_powerlimits((-3,5))
ax.set_xlabel('Firing rate (Hz)')
ax.set_ylabel('Observations')

# Hide the right and top spines. Index ax.spines by name rather than iterating
# with .iteritems(), which does not exist on Python 3.
for loc in ('right', 'top'):
    ax.spines[loc].set_color('none')
# turn off ticks where there is no spine
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')

plt.show()