Example #1
0
class PSTHTuningPanel(wx.Panel):
    """ Bar charts of PSTH tuning and instant firing rate.
    """
    def __init__(self, parent, label, name='psth_panel'):
        """Create the panel state, the figure canvas, the layout and the timer.

        Args:
            parent: parent window passed on to ``wx.Panel``.
            label: caption of the static box that frames the panel.
            name: window name passed on to ``wx.Panel``.
        """
        super(PSTHTuningPanel, self).__init__(parent, -1, name=name)

        # acquisition / redraw flags
        self.connected_to_server = True
        self.collecting_data = True
        self.show_errbar_changed = False
        self.show_fitting_changed = False
        self.showing_errbar = False

        # curve fitting and chart mode
        self.log_fitting = False
        self.curve_fitting = None
        self.curve_fitter = None
        self.append_data_curve = False
        self.polar_chart = False

        # histogram artists and the line styles used per data curve
        self.hist_bins = []
        self.hist_patches = []
        self.data_curves = 1
        self.data_point_styles = ['g.', 'r.', 'b.']
        self.fitting_curve_styles = ['g-', 'r--', 'b-.']

        self.data = None
        self.psth_data = None
        # start_data() replaces the None above with a PSTHTuning instance
        self.start_data()
        self.raw_data = None

        # tuning-curve artists (created by make_chart)
        self.parameter = None
        self.curve_data = None
        self.errbars = None
        self.curve_axes = None

        self.fitting_x = None
        self.fitting_y = None
        self.fitting_data = None

        # per-condition statistics shown on the tuning curve
        self.x = None
        self.means = None
        self.stds = None
        self.mono_left_mean = None
        self.mono_left_std = None
        self.mono_right_mean = None
        self.mono_right_std = None
        self.bg_noise_mean = None
        # update_chart() assigns bg_noise_std next to bg_noise_mean; define it
        # here as well so the attribute exists before the first update.
        self.bg_noise_std = None
        self.mono_dom_mean = None
        self.mono_nod_mean = None

        self.update_data_thread = None

        self.gauss_fitter = None
        self.sinusoid_fitter = None
        self.gabor_fitter = None

        # layout sizer
        box = wx.StaticBox(self, -1, label)
        sizer = wx.StaticBoxSizer(box, wx.VERTICAL)

        # data form
        self.data_form = PSTHTuningDataPanel(self, 'Data form')

        # canvas
        self.dpi = 100
        self.fig = Figure((8.0, 6.0), dpi=self.dpi, facecolor='w')
        self.canvas = FigCanvas(self, -1, self.fig)
        self.make_chart()

        # layout hbox: canvas on the left, data form on the right
        hbox = wx.BoxSizer(wx.HORIZONTAL)
        hbox.Add(self.canvas,
                 0,
                 flag=wx.ALL | wx.ALIGN_LEFT | wx.ALIGN_TOP,
                 border=5)
        hbox.Add(self.data_form,
                 0,
                 flag=wx.ALL | wx.ALIGN_RIGHT | wx.ALIGN_TOP,
                 border=5)

        sizer.Add(hbox, 0, wx.ALIGN_CENTRE)
        self.SetSizer(sizer)
        sizer.Fit(self)

        # fire on_update_data_timer every 1000 ms
        self.update_data_timer = wx.Timer(self, wx.NewId())
        self.Bind(wx.EVT_TIMER, self.on_update_data_timer,
                  self.update_data_timer)
        self.update_data_timer.Start(1000)

    def make_chart(self, data=None, bins=None, polar=False):
        """Clear the figure and rebuild the tuning-curve and histogram axes.

        The tuning curve occupies the upper part of an 18x18 grid and the
        histogram axes are laid out along the bottom rows of the same grid.
        ``self.x``, ``self.means``, ``self.stds`` and the fitting arrays are
        reset to zero-filled placeholders.

        Args:
            data: values handed to ``axes.hist`` for every histogram axes;
                defaults to ``np.zeros(1)``.
            bins: ``bins`` argument handed to ``axes.hist``; defaults to
                ``np.arange(10) + 1``.
            polar: create the tuning-curve axes as a polar plot when true.
        """
        # Build the defaults per call; array defaults in the signature would be
        # created once at definition time and shared by every call.
        if data is None:
            data = np.zeros(1)
        if bins is None:
            bins = np.arange(10) + 1

        self.polar_chart = polar
        self.hist_bins = []
        self.hist_patches = []
        self.x = np.arange(17)
        self.means = np.zeros(self.x.size)
        self.stds = np.zeros(self.x.size)

        self.fitting_x = np.linspace(self.x[0], self.x[-1], 100, endpoint=True)
        self.fitting_y = np.zeros(self.fitting_x.size)
        self.fig.clear()

        grid = 18
        height = grid // 9
        gs = gridspec.GridSpec(grid, grid)
        # make tuning curve plot
        axes = self.fig.add_subplot(gs[:-height * 2, height // 2:-height // 2],
                                    polar=polar)
        if polar:
            self.curve_data = axes.plot(self.x, self.means, 'ko-')[0]
        else:
            adjust_spines(axes,spines=['left','bottom','right'],spine_outward=['left','right','bottom'],xoutward=10,youtward=30,\
                          xticks='bottom',yticks='both',tick_label=['x','y'],xaxis_loc=5,xminor_auto_loc=2,yminor_auto_loc=2)
            axes.set_ylabel('Response(spikes/sec)', fontsize=12)
            self.curve_data = axes.plot(self.x, self.means,
                                        self.data_point_styles[0])[0]
        # error bars are only created when requested; update_chart checks for None
        self.errbars = axes.errorbar(
            self.x, self.means, yerr=self.stds,
            fmt='k.') if self.showing_errbar else None
        self.curve_axes = axes

        self.fitting_data = axes.plot(self.fitting_x, self.fitting_y,
                                      self.fitting_curve_styles[0])[0]

        axes.set_ylim(0, 100)
        axes.relim()
        axes.autoscale_view(scalex=True, scaley=False)
        axes.grid(b=True, which='major', axis='both', linestyle='-.')
        # make histograms plot
        rows, cols = (grid - height, grid)
        for row in range(rows, cols, height):
            for col in range(0, cols, height):
                axes = self.fig.add_subplot(gs[row:row + height,
                                               col:col + height])
                # NOTE(review): set_axis_bgcolor was removed in matplotlib 2.2
                # (set_facecolor replaces it) -- confirm the pinned version.
                axes.set_axis_bgcolor('white')
                #axes.set_title('PSTH', size=8)
                axes.set_ylim(0, 100)
                if col == 0:
                    # leftmost histogram carries the y axis and its label
                    adjust_spines(axes,
                                  spines=['left', 'bottom'],
                                  xticks='bottom',
                                  yticks='left',
                                  tick_label=['y'],
                                  xaxis_loc=4,
                                  yaxis_loc=3)
                    axes.set_ylabel('Spikes', fontsize=11)
                elif col == cols - height:
                    # rightmost histogram carries a y axis on the right
                    adjust_spines(axes,
                                  spines=['right', 'bottom'],
                                  xticks='bottom',
                                  yticks='right',
                                  tick_label=['y'],
                                  xaxis_loc=4,
                                  yaxis_loc=3)
                else:
                    adjust_spines(axes,
                                  spines=['bottom'],
                                  xticks='bottom',
                                  yticks='none',
                                  tick_label=[],
                                  xaxis_loc=4,
                                  yaxis_loc=3)
                pylab.setp(axes.get_xticklabels(), fontsize=8)
                pylab.setp(axes.get_yticklabels(), fontsize=8)
                # keep the returned edges under their own name so the ``bins``
                # argument is not silently replaced between iterations
                _n, bin_edges, patches = axes.hist(data,
                                                   bins,
                                                   facecolor='black',
                                                   alpha=1.0)
                self.hist_bins.append(bin_edges)
                self.hist_patches.append(patches)

        self.fig.canvas.draw()

    def set_data(self, data):
        """Store ``data`` on the panel; ``update_chart`` falls back to it when
        called without an argument."""
        self.data = data

    def update_chart(self, data=None):
        """Redraw histograms, tuning curve and fitted curve for the selected unit.

        Args:
            data: nested mapping indexed as ``data[channel][unit][index]``.
                Every entry is read for ``'mean'`` and ``'std'``; entries with
                an even ``index`` below 16 are also read for ``'spikes'``,
                ``'bins'``, ``'psth_data'`` and ``'trials'``.  When ``None``
                the data stored by ``set_data`` is used; if that is ``None``
                as well the call returns without drawing.
        """
        if data is None:
            if self.data is None:
                return
            data = self.data

        selected_unit = wx.FindWindowByName('unit_choice').get_selected_unit()
        if selected_unit is not None:
            channel, unit = selected_unit
            zeroth_psth_data = None
            polar_dict = {
                'orientation': True,
                'spatial_frequency': False,
                'phase': False,
                'disparity': False
            }
            self.parameter = self.psth_data.parameter
            if self.parameter in polar_dict:
                polar_chart = polar_dict[self.parameter]
            else:
                polar_chart = self.polar_chart
            unit_data = data[channel][unit]
            # histogram
            # Iterate the mapping itself: dict.iterkeys() is Python 2 only,
            # plain iteration yields the keys on both Python 2 and 3.
            for index in [
                    i for i in unit_data
                    if (not i & 1 and i < 16)
            ]:
                patch_index = index // 2
                entry = unit_data[index]
                spike_times = entry['spikes']
                bins = entry['bins']
                psth_data = entry['psth_data']
                if index == 0:
                    zeroth_psth_data = psth_data
                _trials = entry['trials']
                self.show_fitting_changed = False
                # rebuild the whole figure when the bin count, the error-bar
                # option or the polar/cartesian mode changed
                if len(bins) != len(
                        self.hist_bins[patch_index]
                ) or self.show_errbar_changed or polar_chart != self.polar_chart:
                    self.make_chart(spike_times, bins, polar_chart)
                    self.show_errbar_changed = False
                    self.show_fitting_changed = False
                #else:
                for rect, h in zip(self.hist_patches[patch_index], psth_data):
                    rect.set_height(h)

            # collect mean/std per stimulus index; -1, 16 and 17 are stored
            # as background, left-mono and right-mono values respectively
            for index in unit_data:
                mean = unit_data[index]['mean']
                std = unit_data[index]['std']
                if index == -1:
                    self.bg_noise_mean = mean
                    self.bg_noise_std = std
                elif index <= 15:
                    self.means[index] = mean
                    self.stds[index] = std
                elif index == 16:
                    self.mono_left_mean = mean
                    self.mono_left_std = std
                elif index == 17:
                    self.mono_right_mean = mean
                    self.mono_right_std = std

            self.curve_axes.set_xscale('linear')

            if self.parameter == 'orientation':
                self.log_fitting = False
                self.x = np.linspace(0.0, 360.0, 17) / 180 * np.pi
                self.curve_axes.set_title('Orientation Tuning Curve',
                                          fontsize=12)
                if zeroth_psth_data is not None:
                    for rect, h in zip(self.hist_patches[-1],
                                       zeroth_psth_data):
                        rect.set_height(h)
                # the last point repeats the first one to close the curve
                self.means[-1] = self.means[0]
                self.stds[-1] = self.stds[0]
            if self.parameter == 'spatial_frequency':
                self.log_fitting = True
                self.x = np.logspace(-1.0, 0.5, 16)
                self.curve_axes.set_title('Spatial Frequency Tuning Curve',
                                          fontsize=12)
                self.curve_axes.set_xscale('log')
                self.means = self.means[:len(self.x)]
                self.stds = self.stds[:len(self.x)]
                adjust_spines(self.curve_axes,spines=['left','bottom','right'],spine_outward=['left','right','bottom'],xoutward=10,youtward=30,\
                              xticks='bottom',yticks='both',tick_label=['x','y'],xaxis_loc=5,xminor_auto_loc=2,yminor_auto_loc=2,xmajor_loc=[0.1,0.5,1.0,2.0])
            if self.parameter in ('disparity', 'phase'):
                self.log_fitting = False
                self.x = np.linspace(0.0, 360.0, 17)
                if self.parameter == 'disparity':
                    self.curve_axes.set_title('Disparity Tuning Curve',
                                              fontsize=12)
                if self.parameter == 'phase':
                    self.curve_axes.set_title('Phase Tuning Curve',
                                              fontsize=12)
                if zeroth_psth_data is not None:
                    for rect, h in zip(self.hist_patches[-1],
                                       zeroth_psth_data):
                        rect.set_height(h)
                self.means[-1] = self.means[0]
                self.stds[-1] = self.stds[0]
                if self.mono_left_mean is not None and self.mono_right_mean is not None:
                    #annotate dominant eye activity
                    self.mono_dom_mean = max(self.mono_left_mean,
                                             self.mono_right_mean)
                    self.curve_axes.annotate('',
                                             xy=(360, self.mono_dom_mean),
                                             xytext=(370, self.mono_dom_mean),
                                             arrowprops=dict(facecolor='black',
                                                             frac=1.0,
                                                             headwidth=10,
                                                             shrink=0.05))
                    #annotate non-dominant eye activity
                    self.mono_nod_mean = min(self.mono_left_mean,
                                             self.mono_right_mean)
                    self.curve_axes.annotate('',
                                             xy=(360, self.mono_nod_mean),
                                             xytext=(370, self.mono_nod_mean),
                                             arrowprops=dict(facecolor='gray',
                                                             frac=1.0,
                                                             headwidth=10,
                                                             shrink=0.05))
                if self.bg_noise_mean is not None:
                    #annotate background activity
                    self.curve_axes.annotate('',
                                             xy=(360, self.bg_noise_mean),
                                             xytext=(370, self.bg_noise_mean),
                                             arrowprops=dict(facecolor='white',
                                                             frac=1.0,
                                                             headwidth=10,
                                                             shrink=0.05))

                adjust_spines(self.curve_axes,spines=['left','bottom','right'],spine_outward=['left','right','bottom'],xoutward=10,youtward=30,\
                              xticks='bottom',yticks='both',tick_label=['x','y'],xaxis_loc=5,xminor_auto_loc=2,yminor_auto_loc=2)

            if self.append_data_curve:
                # append_data() keeps incrementing data_curves while only a
                # fixed number of styles exists; wrap around instead of
                # raising IndexError on the next appended curve.
                point_style = self.data_point_styles[
                    (self.data_curves - 1) % len(self.data_point_styles)]
                self.curve_axes.plot(self.x, self.means, point_style)
            else:
                self.curve_data.set_xdata(self.x)
                self.curve_data.set_ydata(self.means)
            if self.errbars is not None:
                self._update_errbars(self.errbars, self.x, self.means,
                                     self.stds)

            ##################################################################
            ##### Curve Fitting
            ##################################################################
            if self.log_fitting:
                self.fitting_x = np.logspace(np.log10(self.x[0]),
                                             np.log10(self.x[-1]),
                                             self.fitting_x.size,
                                             endpoint=True)
            else:
                self.fitting_x = np.linspace(self.x[0],
                                             self.x[-1],
                                             self.fitting_x.size,
                                             endpoint=True)

            # zero-filled placeholders are used when no fitting is selected
            model_fitting = np.zeros(self.fitting_x.size)
            model_xdata = np.zeros(self.x.size)
            # only points with a non-zero mean are handed to the fitter
            nonzero = np.nonzero(self.means)[0]
            if self.curve_fitting == 'gauss':
                if self.log_fitting:
                    model_xdata, model_fitting = self.curve_fitter.loggaussfit1d(
                        self.x[nonzero], self.means[nonzero], self.fitting_x)
                else:
                    model_xdata, model_fitting = self.curve_fitter.gaussfit1d(
                        self.x[nonzero], self.means[nonzero], self.fitting_x)
            elif self.curve_fitting == 'sin':
                model_xdata, model_fitting = self.curve_fitter.sinusoid1d(
                    self.x[nonzero], self.means[nonzero], self.fitting_x)
            elif self.curve_fitting == 'gabor':
                model_xdata, model_fitting = self.curve_fitter.gaborfit1d(
                    self.x[nonzero], self.means[nonzero], self.fitting_x)

            if self.append_data_curve:
                # same wrap-around as for the data point styles above
                curve_style = self.fitting_curve_styles[
                    (self.data_curves - 1) % len(self.fitting_curve_styles)]
                self.curve_axes.plot(self.fitting_x, model_fitting,
                                     curve_style)
            else:
                self.fitting_data.set_xdata(self.fitting_x)
                self.fitting_data.set_ydata(model_fitting)

            label = [self.parameter, 'rate', 'std']
            self.data_form.gen_curve_data(self.x, self.means, self.stds,
                                          self.bg_noise_mean,
                                          self.mono_dom_mean,
                                          self.mono_nod_mean, self.fitting_x,
                                          model_fitting, model_xdata, label)
            if self.parameter == 'orientation':
                self.data_form.gen_psth_data(unit_data)
            self.curve_axes.set_xlim(min(self.x), max(self.x))
            # upper limit: 120% of the largest mean, floored to a multiple of 10
            self.curve_axes.set_ylim(min(0, min(self.means)),
                                     (max(self.means) * 1.2) // 10 * 10)
            #self.curve_axes.set_ylim(auto=True)
            self.curve_axes.relim()
            self.curve_axes.autoscale_view(scalex=False, scaley=False)

        self.fig.canvas.draw()

    def _update_errbars(self, errbar, x, means, yerrs):
        """Move an existing error-bar artist group onto new data.

        Args:
            errbar: container indexed as ``errbar[0]`` (data line),
                ``errbar[1]`` (cap lines) and ``errbar[2]`` (bar line
                collections), as kept from ``axes.errorbar`` in ``make_chart``.
            x: x positions of the points.
            means: y values of the points.
            yerrs: symmetric errors; bars span ``means - yerrs`` to
                ``means + yerrs``.
        """
        errbar[0].set_data(x, means)
        # Find the ending points of the errorbars
        lower = means - yerrs
        upper = means + yerrs
        # Update the caplines.  zip() instead of indexing errbar[1][0] and
        # errbar[1][1]: errorbar() creates no cap lines when capsize is 0, and
        # indexing the empty sequence raised IndexError.
        for capline, endpoints in zip(errbar[1], ((x, lower), (x, upper))):
            capline.set_data(endpoints)
        # Update the error bars: (2, 2, N) -> N segments of two (x, y) points
        errbar[2][0].set_segments(
            np.array([[x, lower], [x, upper]]).transpose((2, 0, 1)))

    def on_update_data_timer(self, _event):
        """Timer handler: start an ``UpdateDataThread`` on ``self.psth_data``
        while data is being collected and the panel is connected to the server."""
        if not (self.collecting_data and self.connected_to_server):
            return
        worker = UpdateDataThread(self, self.psth_data)
        self.update_data_thread = worker
        worker.start()

    def start_data(self):
        """Create the ``PSTHTuning`` source if there is none and turn on both
        the collecting and the connected flags."""
        if self.psth_data is None:
            self.psth_data = TimeHistogram.PSTHTuning()
        self.collecting_data = self.connected_to_server = True

    def stop_data(self):
        """Stop collecting, reset the chart and drop the PSTH data source."""
        self.collecting_data = False
        # clear_data() redraws the empty chart and clears the data form
        self.clear_data()
        self.psth_data = None

    def restart_data(self):
        """Call ``stop_data`` followed by ``start_data``."""
        self.stop_data()
        self.start_data()

    def choose_fitting(self, fitting):
        """Select the curve-fitting model used by ``update_chart``.

        Args:
            fitting: ``'none'``, ``'gauss'``, ``'sin'`` or ``'gabor'``; any
                other value leaves the current selection untouched.
        """
        if fitting == 'none':
            self.curve_fitting = None
            self.curve_fitter = None
        elif fitting == 'gauss':
            self.curve_fitting = 'gauss'
            self.curve_fitter = GaussFit()
        elif fitting == 'sin':
            self.curve_fitting = 'sin'
            self.curve_fitter = SinusoidFit()
        elif fitting == 'gabor':
            self.curve_fitting = 'gabor'
            self.curve_fitter = GaborFit()

    def show_errbar(self, checked):
        """Record the error-bar option and flag it as changed so the next
        ``update_chart`` rebuilds the figure."""
        self.show_errbar_changed, self.showing_errbar = True, checked

    def open_file(self, path, callback=None):
        """Load PSTH tuning data from ``path`` on an ``UpdateFileDataThread``
        and mark the panel as not connected to the server.

        Args:
            path: passed to ``TimeHistogram.PSTHTuning``.
            callback: passed through to ``UpdateFileDataThread``.
        """
        self.psth_data = TimeHistogram.PSTHTuning(path)
        UpdateFileDataThread(self, self.psth_data, callback).start()
        self.connected_to_server = False

    def append_data(self, path, callback=None):
        """Open ``path`` as an additional curve: switch ``update_chart`` to
        append mode, count the new curve, then delegate to ``open_file``."""
        self.append_data_curve = True
        self.data_curves = self.data_curves + 1
        self.open_file(path, callback)

    def clear_data(self):
        """Redraw the empty chart, clear the unit selection of the main frame
        and clear the data form."""
        self.make_chart()
        main_frame = wx.FindWindowByName('main_frame')
        main_frame.unit_choice.clear_unit()
        self.data_form.clear_data()

    def save_data(self):
        """Return a dict with the stimulus parameter, the curve x/y values and
        the stored data under ``'stimulus'``, ``'x'``, ``'y'`` and ``'data'``."""
        return {
            'stimulus': self.psth_data.parameter,
            'x': self.x,
            'y': self.means,
            'data': self.data,
        }

    def save_chart(self, path):
        """Write the figure to ``path`` through the canvas at ``self.dpi``."""
        self.canvas.print_figure(path, dpi=self.dpi)
Example #2
0
 def on_update_data_timer(self, _event):
     """Timer handler: start an ``UpdateDataThread`` on ``self.psth_data``
     while data is being collected and the panel is connected to the server."""
     if not (self.collecting_data and self.connected_to_server):
         return
     worker = UpdateDataThread(self, self.psth_data)
     self.update_data_thread = worker
     worker.start()
Example #3
0
class STAPanel(wx.Panel):
    """ Receptive field plot.
    """
    def __init__(self, parent, label, name='sta_panel'):
        """Create the panel state, the canvas, the interpolation menus, the
        option/result panels and the two update timers.

        Args:
            parent: parent window passed on to ``wx.Panel``.
            label: caption of the static box that frames the panel.
            name: window name passed on to ``wx.Panel``.
        """
        super(STAPanel, self).__init__(parent, -1, name=name)

        # redraw flags; all "changed" flags start True so the first
        # update_chart rebuilds the figure
        self.interpolation_changed = True
        self.show_colorbar_changed = True
        self.show_contour_changed = True
        self.showing_colorbar = True
        self.showing_contour = False
        self.image_fitting = None
        self.image_fitter = None

        # default data type
        self.collecting_data = True
        self.connected_to_server = True
        self.data = None
        self.sta_data = None
        self.psth_data = None
        self.data_type = None
        # start_data() creates the data sources from the main frame's data type
        self.start_data()
        self.update_sta_data_thread = None
        self.update_psth_data_thread = None

        self.axes = None
        self.im = None
        self.img_dim = None
        self.cbar = None

        self.gauss_fitter = None
        self.gabor_fitter = None

        self.peak_time = None
        self.cmap = 'jet'
        # reverse time in ms
        time_slider = 85
        # Divide by a float: with two ints this is 0 on Python 2 unless
        # ``from __future__ import division`` is in effect.
        self.time = time_slider / 1000.0

        self.dpi = 100
        self.fig = Figure((6.0, 6.0), dpi=self.dpi, facecolor='w')
        self.canvas = FigCanvas(self, -1, self.fig)
        self.fig.subplots_adjust(bottom=0.06, left=0.06, right=0.95, top=0.95)

        # popup menu of canvas
        interpolations = ['none', 'nearest', 'bilinear', 'bicubic', 'spline16', 'spline36', 'hanning', 'hamming', 'hermite', \
                               'kaiser', 'quadric', 'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc', 'lanczos']
        self.interpolation = 'nearest'
        self.interpolation_menu = wx.Menu()
        for interpolation in interpolations:
            item = self.interpolation_menu.AppendRadioItem(-1, interpolation)
            # check default interpolation
            if interpolation == self.interpolation:
                self.interpolation_menu.Check(item.GetId(), True)
            # the item is reachable from the popup and from the main frame
            # menu, so the handler is bound on both windows
            self.Bind(wx.EVT_MENU, self.on_interpolation_selected, item)
            wx.FindWindowByName('main_frame').Bind(
                wx.EVT_MENU, self.on_interpolation_selected, item)
        self.popup_menu = wx.Menu()
        # AppendSubMenu, as already used for the main frame menu below
        self.popup_menu.AppendSubMenu(self.interpolation_menu,
                                      '&Interpolation')
        self.canvas.Bind(wx.EVT_CONTEXT_MENU, self.on_show_popup)
        wx.FindWindowByName('main_frame').menu_view.AppendSubMenu(
            self.interpolation_menu, '&Interpolation')

        self.make_chart()

        #layout things
        box = wx.StaticBox(self, -1, label)
        sizer = wx.StaticBoxSizer(box, wx.VERTICAL)

        # options
        self.options = OptionPanel(self, 'Options', time=time_slider)
        self.Bind(EVT_TIME_UPDATED, self.on_update_time_slider)
        # results
        self.data_form = STADataPanel(self, 'Results', text_size=(250, 150))

        vbox = wx.BoxSizer(wx.VERTICAL)
        vbox.Add(self.options, 1, wx.TOP | wx.CENTER, 0)
        vbox.Add(self.data_form, 1, wx.TOP | wx.CENTER, 0)

        # canvas on the left, options/results column on the right
        hbox = wx.BoxSizer(wx.HORIZONTAL)
        hbox.Add(self.canvas,
                 0,
                 flag=wx.ALL | wx.ALIGN_LEFT | wx.ALIGN_TOP,
                 border=5)
        hbox.Add(vbox,
                 0,
                 flag=wx.ALL | wx.ALIGN_RIGHT | wx.ALIGN_TOP,
                 border=5)

        sizer.Add(hbox, 0, wx.ALIGN_CENTRE)
        self.SetSizer(sizer)
        sizer.Fit(self)

        # STA data is refreshed every 2000 ms, PSTH data every 3000 ms
        self.update_sta_data_timer = wx.Timer(self, wx.NewId())
        self.Bind(wx.EVT_TIMER, self.on_update_sta_data_timer,
                  self.update_sta_data_timer)
        self.update_sta_data_timer.Start(2000)

        self.update_psth_data_timer = wx.Timer(self, wx.NewId())
        self.Bind(wx.EVT_TIMER, self.on_update_psth_data_timer,
                  self.update_psth_data_timer)
        self.update_psth_data_timer.Start(3000)

    def make_chart(self, img=None):
        """Clear the figure and draw ``img`` on a fresh axes.

        Args:
            img: image array for ``imshow``.  When ``None`` a 32x32 array of
                0.5 is converted with ``self.sta_data.float_to_rgb`` and
                shown instead.
        """
        self.fig.clear()
        self.axes = self.fig.add_subplot(111)
        if img is None:
            placeholder = np.zeros((32, 32)) + 0.5
            img = self.sta_data.float_to_rgb(placeholder, cmap=self.cmap)
        self.img_dim = img.shape
        self.im = self.axes.imshow(img, interpolation=self.interpolation)
        if self.showing_colorbar:
            # ticks start empty; update_chart sets them after a rebuild
            colorbar_options = dict(shrink=1.0,
                                    fraction=0.045,
                                    pad=0.05,
                                    ticks=[])
            self.cbar = self.fig.colorbar(self.im, **colorbar_options)
        adjust_spines(self.axes,
                      spines=['left', 'bottom'],
                      spine_outward=[],
                      xticks='bottom',
                      yticks='left',
                      tick_label=['x', 'y'])
        #self.axes.autoscale_view(scalex=True, scaley=True)
        self.fig.canvas.draw()

    def set_data(self, data):
        """Store ``data`` on the panel; ``update_chart`` falls back to it when
        called without an argument."""
        self.data = data

    def update_slider(self, data):
        """Move the time slider to the selected unit's peak time.

        Nothing happens unless a unit is selected, the auto-time checkbox of
        the options panel is ticked and ``data[channel][unit]`` exists.

        Args:
            data: nested mapping read as ``data[channel][unit]['peak_time']``.
        """
        selected_unit = wx.FindWindowByName('unit_choice').get_selected_unit()
        if not selected_unit:
            return
        channel, unit = selected_unit
        if channel in data and unit in data[channel]:
            peak_time = data[channel][unit]['peak_time']
        else:
            peak_time = None
        main_frame = wx.FindWindowByName('main_frame')
        options_panel = main_frame.chart_panel.options
        auto_time = options_panel.autotime_cb.GetValue()
        if not auto_time or peak_time is None:
            return
        self.peak_time = peak_time
        slider = options_panel.time_slider
        slider.SetValue(peak_time)
        # notify both the options panel and the main frame of the new value
        slider_event = wx.CommandEvent(wx.wxEVT_COMMAND_SLIDER_UPDATED)
        slider_event.SetId(slider.GetId())
        options_panel.on_slider_update(slider_event)
        wx.PostEvent(main_frame, slider_event)

    def update_chart(self, data=None):
        """Redraw the image for the selected unit and, if a fitting model is
        chosen, the fitted image and its results.

        Args:
            data: data handed to ``self.sta_data.get_img``.  When ``None`` the
                data stored by ``set_data`` is used; if that is ``None`` as
                well the call returns without drawing.
        """
        if data is None:
            if self.data is None:
                return
            data = self.data

        sta_data = self.sta_data
        if isinstance(sta_data, RevCorr.STAData):
            self.data_type = 'white_noise'
        if isinstance(sta_data, RevCorr.ParamMapData):
            self.data_type = 'param_mapping'

        selected_unit = wx.FindWindowByName('unit_choice').get_selected_unit()
        if not selected_unit:
            return
        channel, unit = selected_unit

        img = self.sta_data.get_img(data,
                                    channel,
                                    unit,
                                    tau=self.time,
                                    img_format='rgb')
        # the figure is rebuilt on any layout change and on every update
        # while the contour is shown
        needs_rebuild = (self.img_dim != img.shape
                         or self.interpolation_changed
                         or self.show_colorbar_changed
                         or self.showing_contour
                         or self.show_contour_changed)
        if needs_rebuild:
            self.interpolation_changed = False
            self.show_colorbar_changed = False
            self.show_contour_changed = False
            self.make_chart(img)
            self.img_dim = img.shape

            if self.showing_colorbar:
                self.cbar.set_ticks([0.0, 0.5, 1.0])
                if isinstance(self.sta_data, RevCorr.STAData):
                    self.cbar.set_ticklabels(["-1", "0", "1"])
                if isinstance(self.sta_data, RevCorr.ParamMapData):
                    self.cbar.set_ticklabels(["0.0", "0.5", "1.0"])

        self.data_form.gen_results(self.peak_time)

        if self.image_fitting is not None:
            float_img = self.sta_data.get_img(data,
                                              channel,
                                              unit,
                                              tau=self.time,
                                              img_format='float')
            if self.image_fitting == 'gauss':
                params, img = self.image_fitter.gaussfit2d(
                    float_img, returnfitimage=True)
                level = twodgaussian(params)(params[3] + params[5],
                                             params[2] + params[4])
            elif self.image_fitting == 'gabor':
                params, img = self.image_fitter.gaborfit2d(
                    float_img, returnfitimage=True)
                level = twodgabor(params)(params[3] + params[5],
                                          params[2] + params[4])
            if self.showing_contour:
                self.axes.contour(img, [level])
            self.data_form.gen_results(self.peak_time, params, img,
                                       self.data_type, self.image_fitting)
            # the fitted image replaces the raw one on screen
            img = self.sta_data.float_to_rgb(img, cmap=self.cmap)

        self.im.set_data(img)
        self.im.autoscale()
        self.canvas.draw()

    def on_update_sta_data_timer(self, _event):
        """Timer handler: start an ``UpdateDataThread`` on ``self.sta_data``
        while data is being collected and the panel is connected to the server."""
        if not (self.collecting_data and self.connected_to_server):
            return
        worker = UpdateDataThread(self, self.sta_data)
        self.update_sta_data_thread = worker
        worker.start()

    def on_update_psth_data_timer(self, _event):
        """Timer handler: start an ``UpdateDataThread`` on ``self.psth_data``
        while data is being collected and the panel is connected to the server."""
        if not (self.collecting_data and self.connected_to_server):
            return
        worker = UpdateDataThread(self, self.psth_data)
        self.update_psth_data_thread = worker
        worker.start()

    def on_update_time_slider(self, event):
        """Take the time from ``event.get_time()`` and redraw the chart with
        the stored data."""
        self.time = event.get_time()
        self.update_chart()

    def start_data(self):
        """Create the data sources and turn on the collecting and connected
        flags.

        The STA source follows the main frame's data type (``'sparse_noise'``
        or ``'param_map'``; any other value leaves ``self.sta_data`` as is).
        The PSTH source is only created when there is none.
        """
        requested = wx.FindWindowByName('main_frame').get_data_type()
        if requested == 'sparse_noise':
            self.sta_data = RevCorr.STAData()
        elif requested == 'param_map':
            self.sta_data = RevCorr.ParamMapData()
        if self.psth_data is None:
            self.psth_data = TimeHistogram.PSTHAverage()
        self.collecting_data = self.connected_to_server = True

    def stop_data(self):
        """Stop collecting, reset the chart and drop both data sources."""
        self.collecting_data = False
        # clear_data() runs before sta_data is dropped: make_chart() inside it
        # calls self.sta_data.float_to_rgb for the placeholder image
        self.clear_data()
        self.sta_data = None
        self.psth_data = None

    def restart_data(self):
        """Call ``stop_data`` followed by ``start_data``."""
        self.stop_data()
        self.start_data()

    def sparse_noise_data(self):
        """Install a ``RevCorr.STAData`` source and restart data collection."""
        self.sta_data = RevCorr.STAData()
        # NOTE(review): restart_data() -> stop_data() resets self.sta_data to
        # None and start_data() re-creates it from the main frame's data type,
        # so the instance assigned above is only used during the intermediate
        # clear_data()/make_chart() call -- confirm this is intended.
        self.restart_data()

    def param_mapping_data(self):
        """Install a ``RevCorr.ParamMapData`` source and restart data collection."""
        self.sta_data = RevCorr.ParamMapData()
        # NOTE(review): restart_data() -> stop_data() resets self.sta_data to
        # None and start_data() re-creates it from the main frame's data type,
        # so the instance assigned above is only used during the intermediate
        # clear_data()/make_chart() call -- confirm this is intended.
        self.restart_data()

    def choose_fitting(self, fitting):
        """Select the image-fitting model used by ``update_chart``.

        Args:
            fitting: ``'none'``, ``'gauss'`` or ``'gabor'``; any other value
                leaves the current selection untouched.
        """
        if fitting == 'none':
            self.image_fitting = None
            self.image_fitter = None
        elif fitting == 'gauss':
            self.image_fitting = 'gauss'
            self.image_fitter = GaussFit()
        elif fitting == 'gabor':
            self.image_fitting = 'gabor'
            self.image_fitter = GaborFit()

    def show_colorbar(self, checked):
        """Record the colorbar option and flag it as changed so the next
        ``update_chart`` rebuilds the figure."""
        self.show_colorbar_changed, self.showing_colorbar = True, checked

    def show_contour(self, checked):
        """Record the contour option and flag it as changed so the next
        ``update_chart`` rebuilds the figure."""
        self.show_contour_changed, self.showing_contour = True, checked

    def on_show_popup(self, event):
        """Show the popup menu at the event position, converted from screen
        coordinates to the coordinates of the window that raised the event."""
        screen_pos = event.GetPosition()
        source = event.GetEventObject()
        client_pos = source.ScreenToClient(screen_pos)
        self.PopupMenu(self.popup_menu, client_pos)

    def on_interpolation_selected(self, event):
        """Menu handler: adopt the interpolation named by the selected menu
        item and redraw the chart with the stored data."""
        menu_item = self.interpolation_menu.FindItemById(event.GetId())
        chosen = menu_item.GetText()
        if chosen != self.interpolation:
            # update_chart rebuilds the figure when this flag is set
            self.interpolation_changed = True
        self.interpolation = chosen
        if hasattr(self, 'data'):
            self.update_chart(self.data)

    def open_file(self, path, callback):
        """Load STA and PSTH data from ``path`` on two
        ``UpdateFileDataThread`` workers and mark the panel as not connected
        to the server.

        Args:
            path: passed to the ``RevCorr`` / ``TimeHistogram`` constructors.
            callback: passed through to both ``UpdateFileDataThread`` workers.
        """
        requested = wx.FindWindowByName('main_frame').get_data_type()
        if requested == 'sparse_noise':
            self.sta_data = RevCorr.STAData(path)
        elif requested == 'param_map':
            self.sta_data = RevCorr.ParamMapData(path)
        self.psth_data = TimeHistogram.PSTHAverage(path)
        sta_loader = UpdateFileDataThread(self, self.sta_data, callback)
        sta_loader.start()
        psth_loader = UpdateFileDataThread(self, self.psth_data, callback)
        psth_loader.start()
        self.connected_to_server = False

    def save_data(self):
        """Return a dict with the current data type and the stored data under
        ``'data_type'`` and ``'data'``."""
        return {
            'data_type': self.data_type,
            'data': self.data,
        }

    def save_chart(self, path):
        """Write the figure to ``path`` through the canvas at ``self.dpi``."""
        self.canvas.print_figure(path, dpi=self.dpi)

    def clear_data(self):
        """Forget the stored data, redraw the placeholder chart, clear the
        unit selection of the main frame and clear the results panel."""
        self.data = None
        self.make_chart()
        main_frame = wx.FindWindowByName('main_frame')
        main_frame.unit_choice.clear_unit()
        self.data_form.clear_data()
Example #4
0
 def on_update_psth_data_timer(self, _event):
     """Timer handler: start an ``UpdateDataThread`` on ``self.psth_data``
     while data is being collected and the panel is connected to the server."""
     if not (self.collecting_data and self.connected_to_server):
         return
     worker = UpdateDataThread(self, self.psth_data)
     self.update_psth_data_thread = worker
     worker.start()
Example #5
0
class STAPanel(wx.Panel):
    """ Receptive field plot.

    Shows the image returned by ``sta_data.get_img`` for the selected unit on
    a matplotlib canvas, next to an options panel and a results panel.  Two
    wx timers periodically start ``UpdateDataThread`` workers for the STA and
    PSTH data sources while the panel is collecting from the server.
    """
    def __init__(self, parent, label, name='sta_panel'):
        """ Build the panel: state, figure, menus, child panels and timers.

        parent -- parent window passed to ``wx.Panel``
        label  -- caption of the static box that frames the panel
        name   -- window name passed to ``wx.Panel``
        """
        super(STAPanel, self).__init__(parent, -1, name=name)
        
        # redraw flags; any of them being True makes update_chart() rebuild
        # the chart instead of only swapping the image data
        self.interpolation_changed = True
        self.show_colorbar_changed = True
        self.show_contour_changed = True
        self.showing_colorbar = True
        self.showing_contour = False
        self.image_fitting = None
        self.image_fitter = None
        
        # default data type
        self.collecting_data = True
        self.connected_to_server = True
        self.data = None
        self.sta_data = None
        self.psth_data = None
        self.data_type = None
        self.start_data()
        self.update_sta_data_thread = None
        self.update_psth_data_thread = None
        
        self.axes = None
        self.im = None
        self.img_dim = None
        self.cbar = None
        
        self.gauss_fitter = None
        self.gabor_fitter = None
        
        self.peak_time = None
        self.cmap = 'jet'
        # reverse time in ms
        time_slider = 85
        # Float divisor on purpose: int/int truncates to 0 under Python 2
        # integer division, which would silently zero the reverse time.
        self.time = time_slider/1000.0
        
        self.dpi = 100
        self.fig = Figure((6.0, 6.0), dpi=self.dpi, facecolor='w')
        self.canvas = FigCanvas(self, -1, self.fig)
        self.fig.subplots_adjust(bottom=0.06, left=0.06, right=0.95, top=0.95)
        
        # popup menu of canvas
        interpolations = ['none', 'nearest', 'bilinear', 'bicubic', 'spline16', 'spline36', 'hanning', 'hamming', 'hermite',
                          'kaiser', 'quadric', 'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc', 'lanczos']
        self.interpolation = 'nearest'
        self.interpolation_menu = wx.Menu()
        for interpolation in interpolations:
            item = self.interpolation_menu.AppendRadioItem(-1, interpolation)
            # check default interpolation
            if interpolation == self.interpolation:
                self.interpolation_menu.Check(item.GetId(), True)
            # the same menu is reachable from the canvas popup and from the
            # main frame's view menu, so the handler is bound on both
            self.Bind(wx.EVT_MENU, self.on_interpolation_selected, item)
            wx.FindWindowByName('main_frame').Bind(wx.EVT_MENU, self.on_interpolation_selected, item)
        self.popup_menu = wx.Menu()
        self.popup_menu.AppendMenu(-1, '&Interpolation', self.interpolation_menu)
        self.canvas.Bind(wx.EVT_CONTEXT_MENU, self.on_show_popup)
        wx.FindWindowByName('main_frame').menu_view.AppendSubMenu(self.interpolation_menu, '&Interpolation')
        
        self.make_chart()
        
        #layout things
        box = wx.StaticBox(self, -1, label)
        sizer = wx.StaticBoxSizer(box, wx.VERTICAL)
        
        # options
        self.options = OptionPanel(self, 'Options', time=time_slider)
        self.Bind(EVT_TIME_UPDATED, self.on_update_time_slider)
        # results 
        self.data_form = STADataPanel(self, 'Results', text_size=(250,150))
        
        vbox = wx.BoxSizer(wx.VERTICAL)
        vbox.Add(self.options,1,wx.TOP|wx.CENTER, 0)
        vbox.Add(self.data_form,1,wx.TOP|wx.CENTER, 0)
        
        # canvas 
        hbox = wx.BoxSizer(wx.HORIZONTAL)
        hbox.Add(self.canvas, 0, flag=wx.ALL | wx.ALIGN_LEFT | wx.ALIGN_TOP, border=5)
        hbox.Add(vbox, 0, flag=wx.ALL | wx.ALIGN_RIGHT | wx.ALIGN_TOP, border=5)
        
        sizer.Add(hbox, 0, wx.ALIGN_CENTRE)
        self.SetSizer(sizer)
        sizer.Fit(self)

        # STA data is polled every 2000 ms, PSTH data every 3000 ms
        self.update_sta_data_timer = wx.Timer(self, wx.NewId())
        self.Bind(wx.EVT_TIMER, self.on_update_sta_data_timer, self.update_sta_data_timer)
        self.update_sta_data_timer.Start(2000)
        
        self.update_psth_data_timer = wx.Timer(self, wx.NewId())
        self.Bind(wx.EVT_TIMER, self.on_update_psth_data_timer, self.update_psth_data_timer)
        self.update_psth_data_timer.Start(3000)
                
    def make_chart(self, img=None):
        """ Rebuild the figure from scratch and draw ``img``.

        img -- image array to show; when None a flat 32x32 placeholder of
               value 0.5 is drawn instead.
        """
        self.fig.clear()
        self.axes = self.fig.add_subplot(111)
        imshow_kwargs = {'interpolation': self.interpolation}
        if img is None:
            img = np.zeros((32,32)) + 0.5
            if self.sta_data is not None:
                img = self.sta_data.float_to_rgb(img,cmap=self.cmap)
            else:
                # sta_data is None after stop_data(); a later clear_data()
                # (e.g. via restart_data) must not crash here.  Let matplotlib
                # colour the scalar placeholder itself on a fixed 0..1 scale.
                imshow_kwargs.update(cmap=self.cmap, vmin=0.0, vmax=1.0)
        self.img_dim = img.shape
        self.im = self.axes.imshow(img, **imshow_kwargs)
        if self.showing_colorbar:
            self.cbar = self.fig.colorbar(self.im, shrink=1.0, fraction=0.045, pad=0.05, ticks=[])
        adjust_spines(self.axes,spines=['left','bottom'],spine_outward=[],
                      xticks='bottom',yticks='left',tick_label=['x','y'])
        #self.axes.autoscale_view(scalex=True, scaley=True)
        self.fig.canvas.draw()
        
    def set_data(self, data):
        """ Store ``data`` as the panel's current data set. """
        self.data = data
        
    def update_slider(self, data):
        """ Move the time slider to the selected unit's peak time.

        Only acts when a unit is selected, the options panel's auto-time
        checkbox is ticked and ``data`` carries a ``'peak_time'`` for that
        channel/unit.  The slider's own handler is invoked and the event is
        posted to the main frame.
        """
        selected_unit = wx.FindWindowByName('unit_choice').get_selected_unit()
        if selected_unit:
            channel, unit = selected_unit
            peak_time = data[channel][unit]['peak_time'] if channel in data and unit in data[channel] else None
            parent = wx.FindWindowByName('main_frame')
            options_panel = parent.chart_panel.options
            auto_time = options_panel.autotime_cb.GetValue()
            if auto_time and peak_time is not None:
                self.peak_time = peak_time
                options_panel.time_slider.SetValue(peak_time)
                evt = wx.CommandEvent(wx.wxEVT_COMMAND_SLIDER_UPDATED)
                evt.SetId(options_panel.time_slider.GetId())
                options_panel.on_slider_update(evt)
                wx.PostEvent(parent, evt)
    
    def update_chart(self,data=None):
        """ Redraw the image (and optional fit) for the selected unit.

        data -- data set to draw; when omitted the stored ``self.data`` is
                used.  Nothing happens when neither is available or when no
                data source is attached (``sta_data`` is None).
        """
        if data is None and self.data is None:
            return
        elif data is None and self.data is not None:
            data = self.data
        
        # stop_data() drops the data source; a call arriving afterwards has
        # nothing to render with
        if self.sta_data is None:
            return
        
        if isinstance(self.sta_data,RevCorr.STAData):
            self.data_type = 'white_noise'
        if isinstance(self.sta_data,RevCorr.ParamMapData):
            self.data_type = 'param_mapping'
        
        selected_unit = wx.FindWindowByName('unit_choice').get_selected_unit()
        if selected_unit:
            channel, unit = selected_unit
            img = self.sta_data.get_img(data, channel, unit, tau=self.time, img_format='rgb')
            # A full rebuild is needed when the image size or a display option
            # changed; it is also done on every update while a contour is
            # shown, since make_chart() clears the figure.
            if self.img_dim != img.shape or self.interpolation_changed or self.show_colorbar_changed \
            or self.showing_contour or self.show_contour_changed:
                self.interpolation_changed = False
                self.show_colorbar_changed = False
                self.show_contour_changed = False
                self.make_chart(img)
                self.img_dim = img.shape
            
                if self.showing_colorbar:
                    self.cbar.set_ticks([0.0, 0.5, 1.0])
                    if isinstance(self.sta_data, RevCorr.STAData):
                        self.cbar.set_ticklabels(["-1", "0", "1"])
                    if isinstance(self.sta_data, RevCorr.ParamMapData):
                        self.cbar.set_ticklabels(["0.0", "0.5", "1.0"])
            
            self.data_form.gen_results(self.peak_time)
            
            if self.image_fitting is not None:
                # the fit works on the float image; the fitted image replaces
                # the raw one that is finally shown
                float_img = self.sta_data.get_img(data, channel, unit, tau=self.time, img_format='float')
                if self.image_fitting == 'gauss':
                    params,img = self.image_fitter.gaussfit2d(float_img,returnfitimage=True)
                    level = twodgaussian(params)(params[3]+params[5],params[2]+params[4])
                elif self.image_fitting == 'gabor':
                    params,img = self.image_fitter.gaborfit2d(float_img,returnfitimage=True)
                    level = twodgabor(params)(params[3]+params[5],params[2]+params[4])
                if self.showing_contour:
                    self.axes.contour(img, [level])
                self.data_form.gen_results(self.peak_time, params, img, self.data_type, self.image_fitting)
                img = self.sta_data.float_to_rgb(img,cmap=self.cmap)
                
            self.im.set_data(img)
            self.im.autoscale()
            self.canvas.draw()
    
    def on_update_sta_data_timer(self, _event):
        """ Timer handler: start an STA update thread while collecting. """
        if self.collecting_data and self.connected_to_server:
            self.update_sta_data_thread = UpdateDataThread(self, self.sta_data)
            self.update_sta_data_thread.start()
            
    def on_update_psth_data_timer(self, _event):
        """ Timer handler: start a PSTH update thread while collecting. """
        if self.collecting_data and self.connected_to_server:
            self.update_psth_data_thread = UpdateDataThread(self, self.psth_data)
            self.update_psth_data_thread.start()
    
    def on_update_time_slider(self, event):
        """ Adopt the time carried by the slider event and redraw. """
        self.time = event.get_time()
        self.update_chart()
        
    def start_data(self):
        """ Create the data sources and mark the panel as collecting.

        The STA source follows the main frame's data type ('sparse_noise' or
        'param_map'); for any other type ``sta_data`` is left unchanged.  The
        PSTH source is only created when none exists yet.
        """
        data_type = wx.FindWindowByName('main_frame').get_data_type()
        if data_type == 'sparse_noise':
            self.sta_data = RevCorr.STAData()
        elif data_type == 'param_map':
            self.sta_data = RevCorr.ParamMapData()
        if self.psth_data is None:
            self.psth_data = TimeHistogram.PSTHAverage()
        self.collecting_data = True
        self.connected_to_server = True
    
    def stop_data(self):
        """ Stop collecting, clear the display and drop both data sources. """
        self.collecting_data = False
        self.clear_data()
        self.sta_data = None
        self.psth_data = None
        
    def restart_data(self):
        """ Stop and immediately start data collection again. """
        self.stop_data()
        self.start_data()
    
    def sparse_noise_data(self):
        """ Switch to a ``RevCorr.STAData`` source and restart collection. """
        self.sta_data = RevCorr.STAData()
        self.restart_data()
    
    def param_mapping_data(self):
        """ Switch to a ``RevCorr.ParamMapData`` source and restart collection. """
        self.sta_data = RevCorr.ParamMapData()
        self.restart_data()
    
    def choose_fitting(self, fitting):
        """ Select the image fit: 'none', 'gauss' or 'gabor'.

        Any other value leaves the current selection untouched.
        """
        if fitting == 'none':
            self.image_fitting = None
            self.image_fitter = None
        if fitting == 'gauss':
            self.image_fitting = 'gauss'
            self.image_fitter = GaussFit()
        if fitting == 'gabor':
            self.image_fitting = 'gabor'
            self.image_fitter = GaborFit()
        
    def show_colorbar(self, checked):
        """ Record the requested colorbar visibility and flag the change. """
        self.show_colorbar_changed = True
        self.showing_colorbar = checked
        
    def show_contour(self, checked):
        """ Record the requested contour visibility and flag the change. """
        self.show_contour_changed = True
        self.showing_contour = checked
        
    def on_show_popup(self, event):
        """ Show the popup menu at the event position in client coordinates. """
        pos = event.GetPosition()
        pos = event.GetEventObject().ScreenToClient(pos)
        self.PopupMenu(self.popup_menu, pos)
        
    def on_interpolation_selected(self, event):
        """ Menu handler: adopt the clicked interpolation and redraw. """
        item = self.interpolation_menu.FindItemById(event.GetId())
        interpolation = item.GetText()
        if interpolation != self.interpolation:
            self.interpolation_changed = True
        self.interpolation = interpolation
        if hasattr(self, 'data'):
            self.update_chart(self.data)
    
    def open_file(self, path, callback):
        """ Load STA and PSTH data from ``path`` in background threads.

        ``callback`` is handed to each ``UpdateFileDataThread``.  Afterwards
        the panel is marked as not connected to the server, which stops the
        timer handlers from starting update threads.
        """
        data_type = wx.FindWindowByName('main_frame').get_data_type()
        if data_type == 'sparse_noise':
            self.sta_data = RevCorr.STAData(path)
        elif data_type == 'param_map':
            self.sta_data = RevCorr.ParamMapData(path)
        self.psth_data = TimeHistogram.PSTHAverage(path)
        UpdateFileDataThread(self, self.sta_data, callback).start()
        UpdateFileDataThread(self, self.psth_data, callback).start()
        self.connected_to_server = False
    
    def save_data(self):
        """ Return a dict with the current ``data_type`` and ``data``. """
        data_dict = {}
        data_dict['data_type'] = self.data_type
        data_dict['data'] = self.data
        return data_dict
        
    def save_chart(self, path):
        """ Write the figure to ``path`` at the panel's dpi. """
        self.canvas.print_figure(path, dpi=self.dpi)
        
    def clear_data(self):
        """ Drop the stored data and reset chart, unit choice and data form. """
        self.data = None
        self.make_chart()
        wx.FindWindowByName('main_frame').unit_choice.clear_unit()
        self.data_form.clear_data()
Example #6
0
class PSTHAveragePanel(wx.Panel):
    """ Bar charts of spiking latency and instant firing rate.

    Draws the smoothed PSTH of the selected unit as a curve on a matplotlib
    canvas next to a data form.  A wx timer periodically starts an
    ``UpdateDataThread`` while the panel is collecting from the server.
    """
    def __init__(self, parent, label, name='psth_panel'):
        """ Build the panel: state, figure, data form and update timer.

        parent -- parent window passed to ``wx.Panel``
        label  -- caption of the static box that frames the panel
        name   -- window name passed to ``wx.Panel``
        """
        super(PSTHAveragePanel, self).__init__(parent, -1, name=name)
        
        self.connected_to_server = True
        self.collecting_data = True
        self.append_data_curve = False
        self.data_curves = 1
        self.data_point_styles = ['g-','r-','b-']
        
        self.psth_data = TimeHistogram.PSTHAverage()
        self.data = None
        self.raw_data = None
        self.bins = None
        self.bin_data = None
        self.curve_data = None
        self.curve_axes = None
        
        self.update_data_thread = None
        
        # layout sizer
        box = wx.StaticBox(self, -1, label)
        sizer = wx.StaticBoxSizer(box, wx.VERTICAL)
        
        # data form
        self.data_form = PSTHAverageDataPanel(self, 'Data form')
        
        # canvas
        self.dpi = 100
        self.fig = Figure((8.0, 6.0), dpi=self.dpi, facecolor='w')
        self.canvas = FigCanvas(self, -1, self.fig)      
        self.make_chart()
        
        # layout hbox 
        hbox = wx.BoxSizer(wx.HORIZONTAL)
        hbox.Add(self.canvas, 0, flag=wx.ALL | wx.ALIGN_LEFT | wx.ALIGN_TOP, border=5)
        hbox.Add(self.data_form, 0, flag=wx.ALL | wx.ALIGN_RIGHT | wx.ALIGN_TOP, border=5)
        
        sizer.Add(hbox, 0, wx.ALIGN_CENTRE)
        self.SetSizer(sizer)
        sizer.Fit(self)

        # poll for new data every 2000 ms
        self.update_data_timer = wx.Timer(self, wx.NewId())
        self.Bind(wx.EVT_TIMER, self.on_update_data_timer, self.update_data_timer)
        self.update_data_timer.Start(2000)

    def make_chart(self, bins=None, bin_data=None):
        """ Rebuild the figure with a single curve of ``bin_data`` over ``bins``.

        bins     -- x values; defaults to ``np.arange(150)``
        bin_data -- y values; defaults to ``np.zeros(150)``
        """
        # Defaults are created per call: array default arguments would be
        # evaluated once and shared by every call and every panel instance.
        if bins is None:
            bins = np.arange(150)
        if bin_data is None:
            bin_data = np.zeros(150)
        self.bins = bins
        self.bin_data = bin_data
        
        self.fig.clear()
        
        # make Average curve plot
        axes = self.fig.add_subplot(111)
        adjust_spines(axes,spines=['left','bottom','right'],spine_outward=['left','right','bottom'],xoutward=0,youtward=0,
                      xticks='bottom',yticks='both',tick_label=['x','y'],xaxis_loc=7,xminor_auto_loc=2,yminor_auto_loc=2)
        axes.set_xlabel('Time(ms)',fontsize=12)
        axes.set_ylabel('Response(spikes/sec)',fontsize=12)
        self.curve_data = axes.plot(self.bins, self.bin_data, self.data_point_styles[0])[0]
        self.curve_axes = axes
        
        axes.set_ylim(0,100)
        axes.relim()
        axes.autoscale_view(scalex=False, scaley=False)
        #axes.grid(b=True, which='major',axis='both',linestyle='-.')
                
        self.fig.canvas.draw()
        
    def set_data(self, data):
        """ Store ``data`` as the panel's current data set.

        ``update_chart`` and ``save_data`` read ``self.data``; this setter has
        the same shape as the ``set_data`` of the other panels in this file.
        NOTE(review): presumably called by the update threads - confirm.
        """
        self.data = data
        
    def update_chart(self, data=None):
        """ Redraw the smoothed PSTH curve of the selected unit.

        data -- nested mapping ``data[channel][unit]`` with the keys 'bins',
                'smoothed_psth', 'maxima_indices' and 'minima_indices'; when
                omitted the stored ``self.data`` is used.  Returns without
                drawing when the selected channel/unit is not in ``data``.
        """
        if data is None and self.data is None:
            return
        elif data is None and self.data is not None:
            data = self.data
            
        selected_unit = wx.FindWindowByName('unit_choice').get_selected_unit()
        if selected_unit is not None:
            channel, unit = selected_unit
            if channel not in data or unit not in data[channel]:
                return
            #psth_data = data[channel][unit]['psth_data']
            bins = data[channel][unit]['bins']
            smoothed_psth = data[channel][unit]['smoothed_psth']
            maxima_indices = data[channel][unit]['maxima_indices']
            minima_indices = data[channel][unit]['minima_indices']
            
            self.curve_axes.set_xscale('linear')
            if self.append_data_curve:
                # An appended curve is drawn over its own bins, and the style
                # list is cycled so that a 4th, 5th, ... curve does not run
                # past the three defined styles.
                styles = self.data_point_styles
                style = styles[(self.data_curves-1) % len(styles)]
                self.curve_axes.plot(bins, smoothed_psth, style)
                curve_bins = bins
            elif not np.array_equal(self.bins,bins):
                self.make_chart(bins, smoothed_psth)
                curve_bins = self.bins
            else:
                self.curve_data.set_xdata(self.bins)
                self.curve_data.set_ydata(smoothed_psth)
                curve_bins = self.bins
            
            self.data_form.gen_curve_data(curve_bins, smoothed_psth, maxima_indices, minima_indices)
            self.curve_axes.set_xlim(min(self.bins),max(self.bins))
            self.curve_axes.set_ylim(auto=True)
            #self.curve_axes.set_ylim(0,100)
            self.curve_axes.relim()
            self.curve_axes.autoscale_view(scalex=False, scaley=True)
            
        self.fig.canvas.draw()
    
    def on_update_data_timer(self, _event):
        """ Timer handler: start a data update thread while collecting. """
        if self.collecting_data and self.connected_to_server:
            self.update_data_thread = UpdateDataThread(self, self.psth_data)
            self.update_data_thread.start()
        
    def start_data(self):
        """ Ensure a PSTH data source exists and mark the panel as collecting. """
        if self.psth_data is None:
            self.psth_data = TimeHistogram.PSTHAverage()
        self.collecting_data = True
        self.connected_to_server = True
    
    def stop_data(self):
        """ Stop collecting, clear the display and drop the data source. """
        self.collecting_data = False
        self.clear_data()
        self.psth_data = None
        
    def restart_data(self):
        """ Stop and immediately start data collection again. """
        self.stop_data()
        self.start_data()
    
    def smooth_curve(self, checked):
        """ Placeholder; currently does nothing. """
        pass
    
    def open_file(self, path, callback=None):
        """ Load PSTH data from ``path`` in a background thread.

        ``callback`` is handed to the ``UpdateFileDataThread``.  Afterwards
        the panel is marked as not connected to the server.
        """
        self.psth_data = TimeHistogram.PSTHAverage(path)
        data_thread = UpdateFileDataThread(self, self.psth_data, callback)
        data_thread.start()
        self.connected_to_server = False
    
    def append_data(self, path, callback=None):
        """ Open ``path`` as an additional curve on top of the current chart. """
        self.append_data_curve = True
        self.data_curves += 1
        self.open_file(path, callback)
        
    def clear_data(self):
        """ Leave append mode and reset chart, unit choice and data form. """
        self.append_data_curve = False
        self.data_curves = 1
        self.make_chart()
        wx.FindWindowByName('main_frame').unit_choice.clear_unit()
        self.data_form.clear_data()
    
    def save_data(self):
        """ Return a dict with the current ``data``. """
        data_dict = {}
        data_dict['data'] = self.data
        return data_dict
    
    def save_chart(self,path):
        """ Write the figure to ``path`` at the panel's dpi. """
        self.canvas.print_figure(path, dpi=self.dpi)
Example #7
0
class PSTHTuningPanel(wx.Panel):
    """ Bar charts of PSTH tuning and instant firing rate.
    """
    def __init__(self, parent, label, name='psth_panel'):
        """ Create the tuning panel: state attributes, figure and update timer.

        parent -- parent window handed to ``wx.Panel``
        label  -- caption of the static box framing the panel
        name   -- window name handed to ``wx.Panel``
        """
        super(PSTHTuningPanel, self).__init__(parent, -1, name=name)

        # acquisition and redraw flags
        self.connected_to_server = self.collecting_data = True
        self.show_errbar_changed = self.show_fitting_changed = False
        self.showing_errbar = False

        # fitting and chart mode
        self.log_fitting = False
        self.curve_fitting = self.curve_fitter = None
        self.append_data_curve = self.polar_chart = False

        # histogram artists (two distinct lists) and curve styles
        self.hist_bins = []
        self.hist_patches = []
        self.data_curves = 1
        self.data_point_styles = ['g.', 'r.', 'b.']
        self.fitting_curve_styles = ['g-', 'r--', 'b-.']

        # psth_data has to be None here so that start_data() creates the source
        self.data = self.psth_data = None
        self.start_data()
        self.raw_data = None

        # tuning curve artists
        self.parameter = self.curve_data = self.errbars = self.curve_axes = None

        # fitted curve
        self.fitting_x = self.fitting_y = self.fitting_data = None

        # tuning statistics
        self.x = self.means = self.stds = None
        self.mono_left_mean = self.mono_left_std = None
        self.mono_right_mean = self.mono_right_std = None
        self.bg_noise_mean = self.mono_dom_mean = self.mono_nod_mean = None

        self.update_data_thread = None

        self.gauss_fitter = self.sinusoid_fitter = self.gabor_fitter = None

        # layout sizer
        frame_box = wx.StaticBox(self, -1, label)
        outer_sizer = wx.StaticBoxSizer(frame_box, wx.VERTICAL)

        # data form
        self.data_form = PSTHTuningDataPanel(self, 'Data form')

        # canvas
        self.dpi = 100
        self.fig = Figure((8.0, 6.0), dpi=self.dpi, facecolor='w')
        self.canvas = FigCanvas(self, -1, self.fig)
        self.make_chart()

        # canvas on the left, data form on the right
        top_aligned = wx.ALL | wx.ALIGN_TOP
        row = wx.BoxSizer(wx.HORIZONTAL)
        row.Add(self.canvas, 0, flag=top_aligned | wx.ALIGN_LEFT, border=5)
        row.Add(self.data_form, 0, flag=top_aligned | wx.ALIGN_RIGHT, border=5)

        outer_sizer.Add(row, 0, wx.ALIGN_CENTRE)
        self.SetSizer(outer_sizer)
        outer_sizer.Fit(self)

        # poll for new data every 1000 ms
        timer = wx.Timer(self, wx.NewId())
        self.update_data_timer = timer
        self.Bind(wx.EVT_TIMER, self.on_update_data_timer, timer)
        timer.Start(1000)

    def make_chart(self, data=np.zeros(1), bins=np.arange(10)+1, polar=False):
        """ Rebuild the whole figure: tuning curve on top, histograms below.

        data  -- values handed to every histogram's ``hist`` call
        bins  -- bins handed to the first ``hist`` call; later calls reuse the
                 bins returned by the previous one
        polar -- draw the tuning curve on polar axes

        Resets ``x``, ``means``, ``stds``, the fitting arrays and the
        ``hist_bins`` / ``hist_patches`` lists.
        """
        self.polar_chart = polar
        self.hist_bins = []
        self.hist_patches = []
        self.x = np.arange(17)
        self.means = np.zeros(self.x.size)
        self.stds = np.zeros(self.x.size)

        self.fitting_x = np.linspace(self.x[0], self.x[-1], 100, endpoint=True)
        self.fitting_y = np.zeros(self.fitting_x.size)
        self.fig.clear()

        cells = 18
        step = cells // 9
        layout = gridspec.GridSpec(cells, cells)

        # tuning curve: everything above the bottom rows of the grid
        curve_axes = self.fig.add_subplot(layout[:-step*2, step//2:-step//2], polar=polar)
        if polar:
            self.curve_data = curve_axes.plot(self.x, self.means, 'ko-')[0]
        else:
            adjust_spines(curve_axes, spines=['left','bottom','right'], spine_outward=['left','right','bottom'],
                          xoutward=10, youtward=30, xticks='bottom', yticks='both', tick_label=['x','y'],
                          xaxis_loc=5, xminor_auto_loc=2, yminor_auto_loc=2)
            curve_axes.set_ylabel('Response(spikes/sec)', fontsize=12)
            self.curve_data = curve_axes.plot(self.x, self.means, self.data_point_styles[0])[0]
        if self.showing_errbar:
            self.errbars = curve_axes.errorbar(self.x, self.means, yerr=self.stds, fmt='k.')
        else:
            self.errbars = None
        self.curve_axes = curve_axes

        self.fitting_data = curve_axes.plot(self.fitting_x, self.fitting_y, self.fitting_curve_styles[0])[0]

        curve_axes.set_ylim(0, 100)
        curve_axes.relim()
        curve_axes.autoscale_view(scalex=True, scaley=False)
        curve_axes.grid(b=True, which='major', axis='both', linestyle='-.')

        # histograms: one small axes per grid cell in the bottom rows
        first_row, n_cols = (cells - step, cells)
        for row in range(first_row, n_cols, step):
            for col in range(0, n_cols, step):
                hist_axes = self.fig.add_subplot(layout[row:row+step, col:col+step])
                hist_axes.set_axis_bgcolor('white')
                hist_axes.set_ylim(0, 100)
                # only the outermost columns carry a y spine and y ticks
                if col == 0:
                    spine_opts = dict(spines=['left','bottom'], yticks='left', tick_label=['y'])
                elif col == n_cols - step:
                    spine_opts = dict(spines=['right','bottom'], yticks='right', tick_label=['y'])
                else:
                    spine_opts = dict(spines=['bottom'], yticks='none', tick_label=[])
                adjust_spines(hist_axes, xticks='bottom', xaxis_loc=4, yaxis_loc=3, **spine_opts)
                if col == 0:
                    hist_axes.set_ylabel('Spikes', fontsize=11)
                pylab.setp(hist_axes.get_xticklabels(), fontsize=8)
                pylab.setp(hist_axes.get_yticklabels(), fontsize=8)
                _counts, bins, bars = hist_axes.hist(data, bins, facecolor='black', alpha=1.0)
                self.hist_bins.append(bins)
                self.hist_patches.append(bars)

        self.fig.canvas.draw()
    
    def set_data(self, data):
        """ Keep ``data`` as the panel's current data set (``self.data``). """
        self.data = data
    
    def update_chart(self, data=None):
        """ Redraw the PSTH histograms, the tuning curve and the fitted curve.

        data -- nested mapping ``data[channel][unit][index]``; every entry
                carries 'mean' and 'std', and the even indices below 16
                additionally carry 'spikes', 'bins' and 'psth_data'.  When
                omitted the stored ``self.data`` is used.

        Does nothing when no data is available, and returns without drawing
        when the selected channel/unit is not present in ``data``.
        """
        if data is None and self.data is None:
            return
        elif data is None and self.data is not None:
            data = self.data
        
        selected_unit = wx.FindWindowByName('unit_choice').get_selected_unit()
        if selected_unit is not None:
            channel, unit = selected_unit
            # the selected unit may not have any data (yet)
            if channel not in data or unit not in data[channel]:
                return
            unit_data = data[channel][unit]
            zeroth_psth_data = None
            polar_dict = {'orientation':True, 'spatial_frequency':False, 'phase':False, 'disparity':False}
            self.parameter = self.psth_data.parameter
            if self.parameter in polar_dict:
                polar_chart = polar_dict[self.parameter]
            else:
                polar_chart = self.polar_chart
            # histogram: only even indices below 16 are drawn, index 2k in
            # histogram slot k
            for index in [i for i in unit_data if (not i&1 and i<16)]:
                patch_index = index // 2
                spike_times = unit_data[index]['spikes']
                bins = unit_data[index]['bins']
                psth_data = unit_data[index]['psth_data']
                if index == 0:
                    zeroth_psth_data = psth_data
                self.show_fitting_changed = False
                if len(bins) != len(self.hist_bins[patch_index]) or self.show_errbar_changed or polar_chart != self.polar_chart:
                    self.make_chart(spike_times, bins, polar_chart)
                    self.show_errbar_changed = False
                #else:
                for rect,h in zip(self.hist_patches[patch_index],psth_data):
                    rect.set_height(h)
            
            # index -1 feeds the background-noise stats, 16 / 17 the left /
            # right monocular stats; everything else up to 15 is a curve point
            for index in unit_data:
                mean = unit_data[index]['mean']
                std = unit_data[index]['std']
                if index == -1:
                    self.bg_noise_mean = mean
                    self.bg_noise_std = std
                elif index <= 15:
                    self.means[index] = mean
                    self.stds[index] = std
                elif index == 16:
                    self.mono_left_mean = mean
                    self.mono_left_std = std
                elif index == 17:
                    self.mono_right_mean = mean
                    self.mono_right_std = std
            
            self.curve_axes.set_xscale('linear')
            
            if self.parameter == 'orientation':
                self.log_fitting = False
                self.x = np.linspace(0.0, 360.0, 17)/180*np.pi
                self.curve_axes.set_title('Orientation Tuning Curve',fontsize=12)
                # the last slot repeats the first one to close the 0..360 range
                if zeroth_psth_data is not None:
                    for rect,h in zip(self.hist_patches[-1],zeroth_psth_data):
                        rect.set_height(h)
                self.means[-1] = self.means[0]
                self.stds[-1] = self.stds[0]
            if self.parameter == 'spatial_frequency':
                self.log_fitting = True
                self.x = np.logspace(-1.0,0.5,16)
                self.curve_axes.set_title('Spatial Frequency Tuning Curve',fontsize=12)
                self.curve_axes.set_xscale('log')
                self.means = self.means[:len(self.x)]
                self.stds = self.stds[:len(self.x)]
                adjust_spines(self.curve_axes,spines=['left','bottom','right'],spine_outward=['left','right','bottom'],xoutward=10,youtward=30,
                              xticks='bottom',yticks='both',tick_label=['x','y'],xaxis_loc=5,xminor_auto_loc=2,yminor_auto_loc=2,xmajor_loc=[0.1,0.5,1.0,2.0])
            if self.parameter in ('disparity','phase'):
                self.log_fitting = False
                self.x = np.linspace(0.0, 360.0, 17)
                if self.parameter == 'disparity':
                    self.curve_axes.set_title('Disparity Tuning Curve',fontsize=12)
                if self.parameter == 'phase':
                    self.curve_axes.set_title('Phase Tuning Curve',fontsize=12)
                if zeroth_psth_data is not None:
                    for rect,h in zip(self.hist_patches[-1],zeroth_psth_data):
                        rect.set_height(h)
                self.means[-1] = self.means[0]
                self.stds[-1] = self.stds[0]
                if self.mono_left_mean is not None and self.mono_right_mean is not None:
                    #annotate dominant eye activity
                    self.mono_dom_mean = max(self.mono_left_mean, self.mono_right_mean)
                    self.curve_axes.annotate('', xy=(360, self.mono_dom_mean), xytext=(370, self.mono_dom_mean),
                                            arrowprops=dict(facecolor='black', frac=1.0, headwidth=10, shrink=0.05))
                    #annotate non-dominant eye activity
                    self.mono_nod_mean = min(self.mono_left_mean, self.mono_right_mean)
                    self.curve_axes.annotate('', xy=(360, self.mono_nod_mean), xytext=(370, self.mono_nod_mean),
                                            arrowprops=dict(facecolor='gray', frac=1.0, headwidth=10, shrink=0.05))
                if self.bg_noise_mean is not None:
                    #annotate background activity
                    self.curve_axes.annotate('', xy=(360, self.bg_noise_mean), xytext=(370, self.bg_noise_mean),
                                            arrowprops=dict(facecolor='white', frac=1.0, headwidth=10, shrink=0.05))
                    
                adjust_spines(self.curve_axes,spines=['left','bottom','right'],spine_outward=['left','right','bottom'],xoutward=10,youtward=30,
                              xticks='bottom',yticks='both',tick_label=['x','y'],xaxis_loc=5,xminor_auto_loc=2,yminor_auto_loc=2)
            
            # Cycle through the style lists so that a 4th, 5th, ... appended
            # curve does not index past the three defined styles.
            curve_slot = self.data_curves - 1
            point_style = self.data_point_styles[curve_slot % len(self.data_point_styles)]
            fitting_style = self.fitting_curve_styles[curve_slot % len(self.fitting_curve_styles)]
            
            if self.append_data_curve:
                self.curve_axes.plot(self.x, self.means, point_style)
            else:
                self.curve_data.set_xdata(self.x)
                self.curve_data.set_ydata(self.means)
            if self.errbars is not None:
                self._update_errbars(self.errbars,self.x,self.means,self.stds)
            
            ##################################################################
            ##### Curve Fitting
            ##################################################################
            if self.log_fitting:
                self.fitting_x = np.logspace(np.log10(self.x[0]), np.log10(self.x[-1]), self.fitting_x.size, endpoint=True)
            else:
                self.fitting_x = np.linspace(self.x[0], self.x[-1], self.fitting_x.size, endpoint=True)
            
            # without a selected fit the fitted curve stays flat at zero
            model_fitting = np.zeros(self.fitting_x.size)
            model_xdata = np.zeros(self.x.size)
            # only points with a non-zero mean are handed to the fitter
            nonzero = np.nonzero(self.means)[0]
            if self.curve_fitting == 'gauss':
                if self.log_fitting:
                    model_xdata,model_fitting = self.curve_fitter.loggaussfit1d(self.x[nonzero], self.means[nonzero], self.fitting_x)
                else:
                    model_xdata,model_fitting = self.curve_fitter.gaussfit1d(self.x[nonzero], self.means[nonzero], self.fitting_x)
            elif self.curve_fitting == 'sin':
                model_xdata,model_fitting = self.curve_fitter.sinusoid1d(self.x[nonzero], self.means[nonzero], self.fitting_x)
            elif self.curve_fitting == 'gabor':
                model_xdata,model_fitting = self.curve_fitter.gaborfit1d(self.x[nonzero], self.means[nonzero], self.fitting_x)
                
            if self.append_data_curve:
                self.curve_axes.plot(self.fitting_x, model_fitting, fitting_style)
            else:
                self.fitting_data.set_xdata(self.fitting_x)
                self.fitting_data.set_ydata(model_fitting)
                
            label = [self.parameter, 'rate', 'std']
            self.data_form.gen_curve_data(self.x, self.means, self.stds,
                                          self.bg_noise_mean, self.mono_dom_mean, self.mono_nod_mean,
                                          self.fitting_x, model_fitting, model_xdata, label)
            if self.parameter == 'orientation':
                self.data_form.gen_psth_data(unit_data)
            self.curve_axes.set_xlim(min(self.x),max(self.x))
            # Round the upper limit UP to the next multiple of 10 (at least
            # 10).  Rounding down would put the limit below the largest mean
            # (24 -> 20) and collapse the axis for small rates (8 -> 0).
            y_top = max(10.0, float(np.ceil(max(self.means)*1.2/10.0)*10.0))
            self.curve_axes.set_ylim(min(0, min(self.means)), y_top)
            #self.curve_axes.set_ylim(auto=True)
            self.curve_axes.relim()
            self.curve_axes.autoscale_view(scalex=False, scaley=False)
                
        self.fig.canvas.draw()
    
    def _update_errbars(self, errbar, x, means, yerrs):
        errbar[0].set_data(x,means)
        # Find the ending points of the errorbars
        error_positions = (x,means-yerrs), (x,means+yerrs)
        # Update the caplines
        for i,pos in enumerate(error_positions):
            errbar[1][i].set_data(pos)
        # Update the error bars
        errbar[2][0].set_segments(np.array([[x, means-yerrs], [x, means+yerrs]]).transpose((2, 0, 1)))
    
    def on_update_data_timer(self, _event):
        """ Timer handler: start a background update of the PSTH data.

        Does nothing unless data is being collected and the panel is
        connected to the server. The event argument is ignored.
        """
        if not self.collecting_data or not self.connected_to_server:
            return
        worker = UpdateDataThread(self, self.psth_data)
        # Keep a reference to the most recently started worker.
        self.update_data_thread = worker
        worker.start()
        
    def start_data(self):
        """ Enable data collection, creating the PSTH data source if needed.

        An existing ``psth_data`` object is kept; a new
        ``TimeHistogram.PSTHTuning()`` is created only when there is none.
        """
        if self.psth_data is None:
            self.psth_data = TimeHistogram.PSTHTuning()
        # Both flags are switched on together.
        self.collecting_data = self.connected_to_server = True
    
    def stop_data(self):
        """ Stop collecting, clear what is displayed and drop the data source.
        """
        # Turn collection off first so the flag is already down while clearing.
        self.collecting_data = False
        self.clear_data()
        # Forget the data source; start_data() creates a new one when needed.
        self.psth_data = None
        
    def restart_data(self):
        """ Reset the panel: stop_data() followed by start_data().
        """
        # Order matters: stopping drops the old data source, starting
        # then creates a fresh one.
        self.stop_data()
        self.start_data()
        
    def choose_fitting(self, fitting):
        """ Select the curve-fitting model by name.

        Recognised names are ``'none'``, ``'gauss'``, ``'sin'`` and
        ``'gabor'``. ``'none'`` clears both the fitting name and the fitter
        object; the others store the name and a new fitter instance.
        Any other value leaves the current selection untouched.
        """
        if fitting == 'none':
            self.curve_fitting = None
            self.curve_fitter = None
        elif fitting == 'gauss':
            self.curve_fitting = 'gauss'
            self.curve_fitter = GaussFit()
        elif fitting == 'sin':
            self.curve_fitting = 'sin'
            self.curve_fitter = SinusoidFit()
        elif fitting == 'gabor':
            self.curve_fitting = 'gabor'
            self.curve_fitter = GaborFit()
    
    def show_errbar(self, checked):
        """ Record the requested error-bar visibility.

        Stores ``checked`` as the new visibility and raises the
        ``show_errbar_changed`` flag; no drawing happens here.
        """
        self.show_errbar_changed, self.showing_errbar = True, checked
    
    def open_file(self, path, callback=None):
        """ Load PSTH data from ``path`` on a background thread.

        Replaces ``psth_data`` with a ``TimeHistogram.PSTHTuning`` built from
        the path, starts an ``UpdateFileDataThread`` that is given the
        optional ``callback``, and marks the panel as not connected to the
        server.
        """
        self.psth_data = TimeHistogram.PSTHTuning(path)
        file_loader = UpdateFileDataThread(self, self.psth_data, callback)
        file_loader.start()
        # Data now comes from a file, not from the server.
        self.connected_to_server = False
    
    def append_data(self, path, callback=None):
        """ Open ``path`` as an additional data curve.

        Switches the panel into append mode, bumps the curve counter and
        then delegates to open_file() with the same ``path`` and
        ``callback``.
        """
        # Both updates must be in place before the file starts loading.
        self.append_data_curve = True
        self.data_curves += 1
        self.open_file(path, callback)
        
    def clear_data(self):
        """ Reset everything this panel displays.

        Rebuilds the chart with make_chart(), clears the unit choice of the
        window named ``'main_frame'`` and empties the data form.
        """
        self.make_chart()
        main_frame = wx.FindWindowByName('main_frame')
        main_frame.unit_choice.clear_unit()
        self.data_form.clear_data()
    
    def save_data(self):
        """ Return the current tuning data as a plain dict.

        Keys: ``'stimulus'`` (the ``parameter`` of ``psth_data``), ``'x'``,
        ``'y'`` (the means) and ``'data'``.
        """
        return {
            'stimulus': self.psth_data.parameter,
            'x': self.x,
            'y': self.means,
            'data': self.data,
        }
    
    def save_chart(self, path):
        """ Write the current figure to ``path`` at the panel's ``dpi``.
        """
        resolution = self.dpi
        self.canvas.print_figure(path, dpi=resolution)