예제 #1
0
def fill_between_downsampled(ax, x, y1, y2, N_plot_points=4000, **kwargs):
    """Call ``ax.fill_between`` on (x, y1, y2), resampling long inputs.

    Inputs with fewer than ``N_plot_points`` samples are drawn unchanged,
    as in ``plot_downsampled``; short traces are therefore not upsampled
    onto an evenly spaced grid that may skip their original sample points.
    Longer inputs are interpolated onto ``N_plot_points`` evenly spaced x
    values between min(x) and max(x).

    Parameters
    ----------
    ax : object
        Axes-like object providing ``fill_between``.
    x : array-like
        Sample positions.
    y1, y2 : array-like or scalar
        The two curves bounding the filled region. Scalars are expanded to
        the shape of ``x``.
    N_plot_points : int, optional
        Number of resampled points used for long inputs.
    **kwargs
        Forwarded unchanged to ``ax.fill_between``. NOTE(review): per-point
        kwargs such as ``where`` are not resampled and will not match the
        resampled x for long inputs.
    """
    x = np.asarray(x)
    y1 = np.asarray(y1)
    y2 = np.asarray(y2)

    # Expand scalar curves (e.g. a y2=0 baseline, which matplotlib's
    # fill_between accepts) so they can be interpolated like arrays.
    if y1.ndim == 0:
        y1 = np.full(x.shape, y1)
    if y2.ndim == 0:
        y2 = np.full(x.shape, y2)

    # Short (including empty) inputs need no downsampling; this also avoids
    # calling x.min() on an empty array.
    if x.size < N_plot_points:
        ax.fill_between(x, y1, y2, **kwargs)
        return

    new_x = np.linspace(x.min(), x.max(), N_plot_points)
    new_y1 = interpolation_utils.in_ex_polate(x, y1, new_x)
    new_y2 = interpolation_utils.in_ex_polate(x, y2, new_x)
    ax.fill_between(new_x, new_y1, new_y2, **kwargs)
예제 #2
0
def __plot_data(rec_data, rec_time_plot, rec_stim, original_model_output,
                rate_name, Vm_name):
    """Plot this model's recordings against a reference model output.

    Three stacked panels share the x axis:
      1. the normalized recorded stimulus vs. the reference stimulus
         (interpolated onto ``rec_time_plot``),
      2. ``rec_data[rate_name]`` vs. ``original_model_output['rate']``,
      3. ``rec_data[Vm_name]`` vs. ``original_model_output['Vm']``.
    Reference curves are dashed. Recordings whose ``.values`` have more than
    one dimension are averaged over axis 1 first.
    """
    _, axes = plt.subplots(3, 1, figsize=(12, 4), sharex=True)
    stim_ax, rate_ax, vm_ax = axes

    reference_stim = original_model_output['Stimulus']
    stim_ax.plot(rec_time_plot,
                 math_utils.normalize(rec_stim),
                 label='This model')
    stim_ax.plot(rec_time_plot,
                 interpolation_utils.in_ex_polate(
                     x_old=reference_stim['Time'],
                     y_old=reference_stim['Stim'],
                     x_new=rec_time_plot),
                 '--',
                 label='Target')

    def _column_mean(recording):
        # Multi-column recordings are reduced to their mean over axis 1.
        return recording.mean(axis=1) if recording.values.ndim > 1 else recording

    for panel, rec_key, ref_key in ((rate_ax, rate_name, 'rate'),
                                    (vm_ax, Vm_name, 'Vm')):
        panel.plot(rec_time_plot, _column_mean(rec_data[rec_key]))
        panel.plot(original_model_output['Time'],
                   original_model_output[ref_key], '--')

    stim_ax.legend()
    plt.tight_layout()
    plt.show()
    def calc_loss(self, rec_data, plot=False):
        """Compute the loss of recorded current traces against their targets.

        For every ``rec_data[f0][V0]`` entry, the recorded 'Current' and the
        target ``self.target[f0][V0]['Current']`` are both interpolated onto
        1000 evenly spaced times from ``self.t_drop`` to the last target
        time. The per-trace loss is the mean of the squared difference
        divided by the target's maximum. It is capped at
        ``self.max_loss['single']``, and NaN losses are also set to that
        cap. Entries that are None receive the cap directly.

        Parameters
        ----------
        rec_data : dict
            Nested mapping ``{f0: {V0: recording or None}}``; each recording
            provides 'Time' and 'Current'.
        plot : bool, optional
            If True, show one figure per trace comparing trace and target.

        Returns
        -------
        dict
            Per-trace losses keyed ``'f0=<f0>_V0<V0>'`` plus their sum under
            ``'total'``.
        """
        loss = {}

        for f0, rec_data_f in rec_data.items():
            for V0, rec_data_f_i in rec_data_f.items():
                # NOTE(review): the key has no '=' after 'V0' (unlike the plot
                # title); kept unchanged because callers may look keys up.
                loss_key = 'f0=' + str(f0) + '_V0' + str(V0)

                if rec_data_f_i is None:
                    loss[loss_key] = self.max_loss['single']
                    continue

                target_data = self.target[f0][V0]
                target_time = target_data['Time']

                loss_time = np.linspace(self.t_drop, target_time.max(), 1000)

                trace = interpolation_utils.in_ex_polate(
                    x_old=rec_data_f_i['Time'],
                    y_old=rec_data_f_i['Current'],
                    x_new=loss_time)
                target = interpolation_utils.in_ex_polate(
                    x_old=target_time,
                    y_old=target_data['Current'],
                    x_new=loss_time)

                # NOTE(review): if the target is mostly negative or near zero,
                # target.max() is close to 0 and inflates the loss (the sign
                # cancels after squaring); the cap below bounds the result.
                trace_loss = np.mean(((trace - target) / target.max())**2)

                # Cap the loss. NaN (e.g. a NaN trace, or 0/0 from an all-zero
                # target) compares False under >=, so it needs an explicit
                # check or it would make loss['total'] NaN.
                if np.isnan(trace_loss) or (trace_loss >=
                                            self.max_loss['single']):
                    trace_loss = self.max_loss['single']

                loss[loss_key] = trace_loss

                if plot:
                    plt.figure(figsize=(12, 1))
                    plt.title('f0=' + str(f0) + '_' + str(V0) +
                              " --> loss = {:.4g}".format(trace_loss))
                    plt.plot(loss_time, trace, label='trace')
                    plt.plot(loss_time, target, label='target')
                    plt.legend()
                    plt.show()

        loss['total'] = np.sum(list(loss.values()))

        return loss
예제 #4
0
def plot_downsampled(ax, x, y, N_plot_points=4000, **kwargs):
    """Plot y against x on ``ax``, resampling long series.

    Series with fewer than ``N_plot_points`` samples go to ``ax.plot``
    unchanged. Longer ones are interpolated onto ``N_plot_points`` evenly
    spaced x values between min(x) and max(x) before plotting. Extra
    keyword arguments are forwarded to ``ax.plot``.
    """
    x_vals, y_vals = np.asarray(x), np.asarray(y)

    if x_vals.size >= N_plot_points:
        grid = np.linspace(np.min(x_vals), np.max(x_vals), N_plot_points)
        # Interpolate against the original x before replacing it by the grid.
        y_vals = interpolation_utils.in_ex_polate(x_vals, y_vals, grid)
        x_vals = grid

    ax.plot(x_vals, y_vals, **kwargs)
예제 #5
0
    def rate2best_iGluSnFR_trace(self, trace):
        """Convert a rate trace to iGluSnFR and fit it to the target.

        The trace is converted with ``iGluSnFR_utils.rate2iGluSnFR`` (on
        ``self.rec_time``, with ``n_drop=self.n_drop_trace``). The result is
        interpolated onto ``self.target_time`` and passed to
        ``lin_trans_utils.best_lin_trans`` together with ``self.target`` and
        ``self.compute_iGluSnFR_trace_loss`` as loss function.

        Returns
        -------
        tuple
            ``(transformed_trace, loss)`` as unpacked from ``best_lin_trans``.
        """
        converted = iGluSnFR_utils.rate2iGluSnFR(trace,
                                                 rec_time=self.rec_time,
                                                 n_drop=self.n_drop_trace)

        on_target_grid = interpolation_utils.in_ex_polate(
            x_old=self.rec_time,
            y_old=converted,
            x_new=self.target_time,
        )

        fitted_trace, fit_loss = lin_trans_utils.best_lin_trans(
            trace=on_target_grid,
            target=self.target,
            loss_fun=self.compute_iGluSnFR_trace_loss,
        )
        return fitted_trace, fit_loss
예제 #6
0
    def __compute_total_charge_var_dur(self,
                                       param_x,
                                       params,
                                       return_charge=True,
                                       return_anchors=False):
        """Build a variable-duration spline stimulus from parameters.

        Parameters
        ----------
        param_x : float
            Amplitude of the anchor placed just before the closing zero
            anchor.
        params : array-like
            ``self.n_params`` values (flattened). ``params[:-1]`` are anchor
            amplitudes. ``params[-1]`` must lie in [0, 1] and sets where the
            stimulus ends. The stop index is the first sample with time
            greater than ``stim_time.max() - postdur - (1 - params[-1]) *
            (longest - self.stim_time_min)``, where ``longest =
            stim_time.max() - predur - postdur``. If no sample qualifies,
            the stop index is ``stim_time.size``.
        return_charge : bool, optional
            If True, return ``abs(sum(stim))``.
        return_anchors : bool, optional
            If True and ``return_charge`` is False, also return the anchor
            times and amplitudes.

        Returns
        -------
        float, ndarray or tuple
            ``abs(sum(stim))``, or ``(stim, anchor_times, anchor_amps)``, or
            ``stim`` sampled at ``self.stim_time``.
        """
        params = np.asarray(params).flatten()
        assert params.size == self.n_params

        duration_frac = params[-1]
        assert 0 <= duration_frac <= 1

        stim_time = self.stim_time
        longest = stim_time.max() - self.predur - self.postdur
        assert longest > self.stim_time_min

        cut_before_end = self.postdur + (1 - duration_frac) * (
            longest - self.stim_time_min)

        # Samples that lie within `cut_before_end` of the last time point.
        after_stop = np.where(
            (stim_time - stim_time.max()) > -cut_before_end)[0]
        idx_stop_stim = after_stop[0] if after_stop.size else stim_time.size

        # Evenly spaced (int-truncated) anchors: leading zero, the amplitude
        # parameters, param_x, trailing zero.
        anchor_idxs = np.linspace(self.idx_start_stim,
                                  idx_stop_stim - 1,
                                  params.size + 2,
                                  dtype='int')
        anchor_times = stim_time[anchor_idxs]
        anchor_amps = np.concatenate(
            [np.zeros(1), params[:-1], param_x * np.ones(1),
             np.zeros(1)])

        stim = interpolation_utils.in_ex_polate(
            x_old=anchor_times,
            y_old=anchor_amps,
            x_new=stim_time,
            kind=self.spline_mode,
        )

        if return_charge:
            return np.abs(np.sum(stim))
        if return_anchors:
            return stim, anchor_times, anchor_amps
        return stim
예제 #7
0
    def __create_stim_from_params(self, params):
        """Build a spline stimulus from anchor amplitudes.

        Parameters
        ----------
        params : array-like
            ``self.n_params + 1`` anchor amplitudes (flattened). A zero
            anchor is added before and after them.

        Returns
        -------
        tuple
            ``(stim, stim_anchor_points_time, stim_anchor_points_amp)``:
            the stimulus interpolated (``kind=self.spline_mode``) onto
            ``self.stim_time``, plus the anchor times and amplitudes.

        Raises
        ------
        AssertionError
            If ``params`` does not contain ``self.n_params + 1`` values.
        """
        params = np.asarray(params).flatten()
        expected_size = self.n_params + 1
        # Report the size that is actually required (n_params + 1); printing
        # n_params produced contradictory messages such as '5!=5'.
        assert params.size == expected_size, (str(params.size) + '!=' +
                                              str(expected_size))

        # Anchor indexes evenly spaced (int-truncated) from idx_start_stim to
        # idx_stop_stim - 1; the two extra anchors hold the zero end points.
        stim_anchor_idxs = np.linspace(self.idx_start_stim,
                                       self.idx_stop_stim - 1,
                                       params.size + 2,
                                       dtype='int')
        stim_anchor_points_time = self.stim_time[stim_anchor_idxs]
        stim_anchor_points_amp = np.concatenate(
            [np.zeros(1), params, np.zeros(1)])

        stim = interpolation_utils.in_ex_polate(x_old=stim_anchor_points_time,
                                                y_old=stim_anchor_points_amp,
                                                x_new=self.stim_time,
                                                kind=self.spline_mode)

        return stim, stim_anchor_points_time, stim_anchor_points_amp
예제 #8
0
    def update_target(self, rec_time=None):
        """Restrict and interpolate the target to the recording time.

        Parameters
        ----------
        rec_time : array-like, optional
            If given, replaces ``self.rec_time`` (as an ndarray).

        Sets ``self.n_drop_trace`` and ``self.n_drop_target`` via
        ``self.get_n_drop``. Sets ``self.target_time`` to the samples of
        ``self.target_original['Time']`` within
        [rec_time[0], rec_time[-1]]. Sets ``self.target`` to
        ``target_original['mean']`` interpolated onto ``self.target_time``.
        """
        if rec_time is not None:
            self.rec_time = np.asarray(rec_time)

        self.n_drop_trace = self.get_n_drop(self.rec_time, self.t_drop)

        source_time = self.target_original['Time'].values
        t_first, t_last = self.rec_time[0], self.rec_time[-1]
        in_window = (source_time >= t_first) & (source_time <= t_last)
        self.target_time = source_time[in_window]

        self.n_drop_target = self.get_n_drop(self.target_time, self.t_drop)

        self.target = interpolation_utils.in_ex_polate(
            x_old=source_time,
            y_old=self.target_original['mean'].values,
            x_new=self.target_time,
        )
예제 #9
0
def __numerical_comparison(time1, trace1, time2, trace2):
    """Print how closely two sampled traces agree.

    If both time vectors have the same size and are ``np.allclose``, the
    traces are compared sample by sample. Otherwise both traces are
    interpolated onto the samples of ``time1`` inside the window covered
    by both time vectors. If that window is empty, this is reported and
    nothing else is compared. The function prints whether the traces are
    exactly equal, close within ``rtol=atol=1e-2``, or different, together
    with the maximum absolute error.

    NOTE(review): the window uses the first/last elements as bounds, which
    assumes both time vectors are sorted ascending — confirm with callers.
    """
    time1 = np.asarray(time1)
    time2 = np.asarray(time2)
    trace1 = np.asarray(trace1)
    trace2 = np.asarray(trace2)

    if (time1.size == time2.size) and np.allclose(time1, time2):
        print('Times are the same.')

    else:
        print('Times are not the same. Use interpolation')

        # Guard the [0]/[-1] lookups: an empty time vector has no overlap.
        in_overlap = np.zeros(time1.shape, dtype=bool)
        if time1.size and time2.size:
            in_overlap = ((time1 >= np.max([time1[0], time2[0]]))
                          & (time1 <= np.min([time1[-1], time2[-1]])))
        time_interpol = time1[in_overlap]

        # Without a common window there is nothing to compare; np.all() on
        # the resulting empty arrays would otherwise report identical traces.
        if time_interpol.size == 0:
            print('Time ranges do not overlap. Traces cannot be compared.')
            return

        trace1 = interpolation_utils.in_ex_polate(x_old=time1,
                                                  y_old=trace1,
                                                  x_new=time_interpol)
        trace2 = interpolation_utils.in_ex_polate(x_old=time2,
                                                  y_old=trace2,
                                                  x_new=time_interpol)

    if np.all(trace1 == trace2):
        print('Traces are exactly the same.')
    elif np.allclose(trace1, trace2, rtol=1e-2, atol=1e-2):
        max_err = np.max(np.abs(trace2 - trace1))
        print(
            'Traces are very close, differences might be due to rounding errors. Max error = {:.2g}'
            .format(max_err))
    else:
        max_err = np.max(np.abs(trace2 - trace1))
        print('Traces are not the equal. Max error = {:.2g}'.format(max_err))


#############################################################################
#def test_CBCs(self, filename_ON, filename_OFF, t_rng=None):
#  ON_model_output = data_utils.load_var(filename_ON)
#  OFF_model_output = data_utils.load_var(filename_OFF)
#
#  backup_stim = self.stim
#  backup_stim_type = self.stim_type
#
#  # Run with original stimulus.
#  assert np.all(ON_model_output['Stimulus'] == OFF_model_output['Stimulus'])
#  self.set_stim(ON_model_output['Stimulus'], stim_type='Light')
#
#  if t_rng is not None: self.update_t_rng(t_rng)
#  else:                 self.update_t_rng(cone_model_output['t_rng'])
#
#  try:
#    rec_data = self.run(
#      sim_params={'g_ac_hb': 0.0, 'g_db_ac': 0.0},
#      rec_type='test', plot=False, verbose=False, reset_retsim_stim=True,
#    )
#  except:
#    rec_data = None
#    pass
#
#  # Reset stimulus.
#  self.set_stim(backup_stim, backup_stim_type)
#
#  # Plot comparison.
#  if rec_data is not None:
#    plt.figure(figsize=(12,6))
#    plt.subplot(511)
#    plt.plot(rec_data[1]+self.get_t_rng()[0], math_utils.normalize(rec_data[2]), label='This model')
#    plt.plot(ON_model_output['Stimulus']['Time'], ON_model_output['Stimulus']['Stim'], '--', label='Target')
#    plt.legend()
#
#    plt.subplot(512)
#    plt.plot(rec_data[1]+self.get_t_rng()[0], rec_data[0]['rate BC ON'].mean(axis=1))
#    plt.plot(ON_model_output['Time']+ON_model_output['t_rng'][0], ON_model_output['rate'], '--')
#
#    plt.subplot(513)
#    plt.plot(rec_data[1]+self.get_t_rng()[0], rec_data[0]['BC Vm Soma ON'])
#    plt.plot(ON_model_output['Time']+ON_model_output['t_rng'][0], ON_model_output['Vm'], '--')
#
#    plt.subplot(514)
#    plt.plot(rec_data[1]+self.get_t_rng()[0], rec_data[0]['rate BC OFF'].mean(axis=1))
#    plt.plot(OFF_model_output['Time']+OFF_model_output['t_rng'][0], OFF_model_output['rate'], '--')
#
#    plt.subplot(515)
#    plt.plot(rec_data[1]+self.get_t_rng()[0], rec_data[0]['BC Vm Soma OFF'])
#    plt.plot(OFF_model_output['Time']+OFF_model_output['t_rng'][0], OFF_model_output['Vm'], '--')
#
#    plt.tight_layout()
#    plt.show()
#  else:
#    print('rec_data was None')
예제 #10
0
    def plot_loss_params(self):
        """Plot every loss function configured in ``self.loss_params``.

        One panel per entry. Entries containing both 'good' and
        'acceptable' bounds show ``self.loss_value_in_range`` over a range
        around those bounds. Otherwise, entries whose name contains
        'iGluSnFR' show ``self.loss_iGluSnFR`` for a set of example traces.
        """
        n_losses = len(self.loss_params)
        # squeeze=False always yields a 2D array of Axes; with a single loss
        # the default would return a bare Axes, which zip() cannot iterate.
        _, axs = plt.subplots(n_losses,
                              1,
                              figsize=(12, n_losses * 1.5),
                              squeeze=False)

        for ax, (loss_name, loss_dict) in zip(axs[:, 0],
                                              self.loss_params.items()):
            ax.set_title(loss_name)

            if ('good' in loss_dict) and ('acceptable' in loss_dict):
                self._plot_range_loss_param(ax, loss_dict)
            elif 'iGluSnFR' in loss_name:
                self._plot_iGluSnFR_loss_examples(ax)

        plt.tight_layout()
        plt.show()

    def _plot_range_loss_param(self, ax, loss_dict):
        """Plot ``loss_value_in_range`` around a loss's good/acceptable bounds.

        For each side (lower/upper) where both the good and the acceptable
        bound are given, the plot range extends half their distance beyond
        the acceptable bound. If only one side is available, the other end
        is placed twice that side's distance away from the computed end.
        All given bounds become x ticks marked by dashed vertical lines.

        Raises
        ------
        ValueError
            If neither side has both a good and an acceptable bound.
        """
        good = loss_dict['good']
        acceptable = loss_dict['acceptable']

        # Tick order: lower good, lower acceptable, upper good, upper
        # acceptable; missing (None) bounds are skipped.
        xticks = [
            bound for bound in (good[0], acceptable[0], good[1], acceptable[1])
            if bound is not None
        ]

        # Locals are fresh per call, so bounds cannot leak between losses.
        lb = ub = None
        lb_rng = ub_rng = None
        if (good[0] is not None) and (acceptable[0] is not None):
            lb_rng = np.abs(good[0] - acceptable[0])
            lb = acceptable[0] - 0.5 * lb_rng
        if (good[1] is not None) and (acceptable[1] is not None):
            ub_rng = np.abs(good[1] - acceptable[1])
            ub = acceptable[1] + 0.5 * ub_rng

        if lb is None and ub is None:
            raise ValueError(
                'Cannot choose a plot range: need both a good and an '
                'acceptable value for the lower or the upper bound.')
        if lb is None:
            lb = ub - 2 * ub_rng
        elif ub is None:
            ub = lb + 2 * lb_rng

        in_values = np.linspace(lb, ub, 100)
        out_values = np.full(in_values.size, np.nan)
        for idx, value in enumerate(in_values):
            out_values[idx] = self.loss_value_in_range(value=value,
                                                       good=good,
                                                       acceptable=acceptable)

        ax.plot(in_values, out_values)
        ax.set_xticks(xticks)

        for xtick in xticks:
            ax.axvline(xtick, c='gray', ls='--')
        if not self.absolute:
            ax.axhline(0, c='gray', ls='--')

    def _plot_iGluSnFR_loss_examples(self, ax):
        """Bar-plot ``self.loss_iGluSnFR`` for a set of example traces (log y)."""
        n_samples = self.rec_time.size

        plot_losses_dict = {}
        plot_losses_dict['sinus'] = self.loss_iGluSnFR(
            np.sin(10 * self.rec_time))
        plot_losses_dict['zeros'] = self.loss_iGluSnFR(np.zeros(n_samples))
        # Previously np.zeros, which duplicated the 'zeros' bar.
        plot_losses_dict['ones'] = self.loss_iGluSnFR(np.ones(n_samples))
        plot_losses_dict['slope'] = self.loss_iGluSnFR(-np.arange(n_samples))
        plot_losses_dict['noise'] = self.loss_iGluSnFR(
            np.random.uniform(-1, 1, n_samples))

        interpol_target = interpolation_utils.in_ex_polate(
            self.target_time, self.target, self.rec_time)

        plot_losses_dict['noisy target'] = self.loss_iGluSnFR(
            interpol_target +
            np.random.normal(0, np.std(self.target), n_samples))
        plot_losses_dict['flipped target'] = self.loss_iGluSnFR(
            -interpol_target)

        positions = np.arange(len(plot_losses_dict))
        loss_values = list(plot_losses_dict.values())

        ax.set_yscale('log')
        ax.bar(positions, loss_values)
        ax.set_xticks(positions)
        ax.set_xticklabels(list(plot_losses_dict.keys()))
        for idx, plot_losses_value in enumerate(loss_values):
            ax.text(idx,
                    1,
                    "{:.4f}".format(plot_losses_value),
                    ha='center',
                    va='top')