Exemplo n.º 1
0
def dendrite_model__parameter_sweep(data_file_list, L_di_vec, tau_di_vec,
                                    dt_vec, tf_vec, drive_info, amp_vec,
                                    mu1_vec, mu2_vec, mu3_vec, mu4_vec,
                                    master_error_plot_name):
    """Fit the dendrite model to WR data with two successive grid searches.

    For every entry of ``data_file_list`` the function
      1. reads the WR data file and compares the modelled drive with the
         recorded drive,
      2. sweeps (amp, mu1, mu2) with mu3/mu4 held at their initial values and
         keeps the combination with the smallest chi-squared error,
      3. sweeps (amp, mu3, mu4) with mu1/mu2 fixed at the values found in
         step 2,
      4. after each sweep re-runs the best combination and plots it against
         the WR data.
    The per-file error matrices are summed into two master matrices, which
    are saved with ``save_session_data`` and plotted once per amplitude.

    Parameters
    ----------
    data_file_list : sequence of str
        WR data file names without the '.dat' extension.
    L_di_vec, tau_di_vec, dt_vec, tf_vec : sequences
        Indexed by file position: integration-loop self inductance,
        integration-loop time constant, time step and final time.
    drive_info : dict
        Must hold 'drive_type' ('piecewise_linear', 'linear_ramp',
        'sq_pls_trn' or 'exp_pls_trn') and the matching parameter entry
        ('pwl_drive', 'sq_pls_trn_params' or 'exp_pls_trn_params').
    amp_vec, mu1_vec, mu2_vec, mu3_vec, mu4_vec : sequences
        Grid values for the swept parameters.
    master_error_plot_name : str
        Label for the master plots; 'mstr_err__' is prepended.

    Returns
    -------
    best_params : dict
        Keys 'amp_mu12', 'amp_mu34', 'mu1', 'mu2', 'mu3', 'mu4'; each value
        is a list with one entry per data file.
    error_mat_master__mu1_mu2 : ndarray, shape (num_amps, num_mu1, num_mu2)
    error_mat_master__mu3_mu4 : ndarray, shape (num_amps, num_mu3, num_mu4)

    Raises
    ------
    ValueError
        If ``drive_info['drive_type']`` is not one of the supported types.
    """
    master_error_plot_name = 'mstr_err__' + master_error_plot_name

    best_params = {
        'amp_mu12': [],
        'amp_mu34': [],
        'mu1': [],
        'mu2': [],
        'mu3': [],
        'mu4': [],
    }

    data_array = {
        'amp_vec': amp_vec,
        'mu1_vec': mu1_vec,
        'mu2_vec': mu2_vec,
        'mu3_vec': mu3_vec,
        'mu4_vec': mu4_vec,
    }

    num_sims = len(data_file_list)
    num_amps = len(amp_vec)
    num_mu1 = len(mu1_vec)
    num_mu2 = len(mu2_vec)
    num_mu3 = len(mu3_vec)
    num_mu4 = len(mu4_vec)

    mu3_initial = 0  #1
    mu4_initial = 0  #0.5

    print(
        '\n\nrunning dendrite_model__parameter_sweep\n\nnum_files = {:d}\nnum_amps = {:d}\nnum_mu1 = {:d}\nnum_mu2 = {:d}\nnum_mu3 = {:d}\nnum_mu4 = {:d}'
        .format(num_sims, num_amps, num_mu1, num_mu2, num_mu3, num_mu4))

    error_mat_master__mu1_mu2 = np.zeros([num_amps, num_mu1, num_mu2])
    error_mat_master__mu3_mu4 = np.zeros([num_amps, num_mu3, num_mu4])

    # Fail early on an unsupported drive type instead of hitting an undefined
    # variable part-way through the first file.
    drive = _dendrite_sweep__drive_settings(drive_info)

    for ii, data_file in enumerate(data_file_list):
        print('\ndata_file {} of {}\n'.format(ii + 1, num_sims))
        plt.close('all')

        dt = dt_vec[ii]
        tf = tf_vec[ii]
        L_di = L_di_vec[ii]
        tau_di = tau_di_vec[ii]

        # WR data
        print('reading wr data ...\n')
        data_dict = read_wr_data('wrspice_data/' + drive['directory'] + '/' +
                                 data_file + '.dat')
        target_data = np.vstack(
            (data_dict['time'], data_dict[drive['wr_target']]))

        #----------------------
        # compare drive signals
        #----------------------
        target_data__drive = np.vstack(
            (data_dict['time'], data_dict[drive['wr_drive']]))

        input_1, dendritic_drive = _dendrite_sweep__build_input(drive, dt, tf)
        actual_data__drive = np.vstack(
            (input_1.time_vec[:], dendritic_drive[:, 0]))
        print('comparing drive signals ...\n')
        error__drive = chi_squared_error(target_data__drive,
                                         actual_data__drive)

        #------------------------
        # find best amp, mu1, mu2
        #------------------------
        # NOTE(review): mu3/mu4 are passed as one-element lists here but as
        # scalars in the second sweep; kept as in the original -- confirm
        # that the dendrite model accepts both forms.
        mu3 = [mu3_initial]  #np.linspace([0.5,2.5,10])
        mu4 = [mu4_initial]  #np.linspace([0.25,1.5,10])

        error_mat_1 = np.zeros([num_amps, num_mu1, num_mu2])
        print('seeking amp, mu1, mu2 ...')
        for aa in range(num_amps):
            for bb in range(num_mu1):
                for cc in range(num_mu2):

                    print('aa = {} of {}, bb = {} of {}, cc = {} of {}'.format(
                        aa + 1, num_amps, bb + 1, num_mu1, cc + 1, num_mu2))

                    sim_params = _dendrite_sweep__sim_params(
                        amp_vec[aa], mu1_vec[bb], mu2_vec[cc], mu3, mu4, dt,
                        tf)
                    actual_data = _dendrite_sweep__simulate(
                        input_1, L_di, tau_di, sim_params)
                    error_mat_1[aa, bb, cc] = chi_squared_error(
                        target_data, actual_data)

        error_mat_master__mu1_mu2 += error_mat_1
        aa_best, bb_best, cc_best = _dendrite_sweep__first_minimum(
            error_mat_1)
        amp_best_mu12 = amp_vec[aa_best]
        mu1_best = mu1_vec[bb_best]
        mu2_best = mu2_vec[cc_best]
        print('\n\namp_best_mu12 = {}'.format(amp_best_mu12))
        print('mu1_best = {}'.format(mu1_best))
        print('mu2_best = {}\n\n'.format(mu2_best))
        best_params['amp_mu12'].append(amp_best_mu12)
        best_params['mu1'].append(mu1_best)
        best_params['mu2'].append(mu2_best)
        # Overwritten on every file: only the last file's matrix is saved.
        data_array['error_mat__amp_mu1_mu2'] = error_mat_1

        #plot errors
        title_string = '{}\namp_best_mu12 = {:2.2f}, mu1_best = {:1.2f}, mu2_best = {:1.2f}'.format(
            data_file, amp_best_mu12, mu1_best, mu2_best)
        save_str = '{}__error__mu1_mu2'.format(data_file)
        for aa in range(num_amps):
            plot_error_mat(error_mat_1[aa, :, :], mu1_vec, mu2_vec, 'mu1',
                           'mu2', 'amp = {}'.format(amp_vec[aa]), title_string,
                           save_str)

        #repeat best one and plot
        input_1, _ = _dendrite_sweep__build_input(drive, dt, tf)
        sim_params = _dendrite_sweep__sim_params(amp_best_mu12, mu1_best,
                                                 mu2_best, mu3, mu4, dt, tf)
        actual_data = _dendrite_sweep__simulate(input_1, L_di, tau_di,
                                                sim_params)
        main_title = '{}\namp_best_{:2.2f}_mu1_best_{:1.4f}_mu2_best_{:1.4f}'.format(
            data_file, amp_best_mu12, mu1_best, mu2_best)
        error__signal = np.amin(error_mat_1)
        plot_wr_comparison__drive_and_response(main_title, target_data__drive,
                                               actual_data__drive, target_data,
                                               actual_data, data_file,
                                               error__drive, error__signal)

        #-----------------------------
        # find best amp_mu34, mu3, mu4
        #-----------------------------
        error_mat_2 = np.zeros([num_amps, num_mu3, num_mu4])
        print('seeking amp, mu3, mu4 ...')
        # The inner loops must run over the mu3/mu4 grids: they index
        # mu3_vec/mu4_vec and fill error_mat_2, whose shape is
        # (num_amps, num_mu3, num_mu4).  Looping over num_mu1/num_mu2 raised
        # IndexError or left part of the matrix at zero whenever the grid
        # sizes differed.
        for aa in range(num_amps):
            for bb in range(num_mu3):
                for cc in range(num_mu4):

                    print('aa = {} of {}, bb = {} of {}, cc = {} of {}'.format(
                        aa + 1, num_amps, bb + 1, num_mu3, cc + 1, num_mu4))

                    # NOTE(review): the input signal is rebuilt for every grid
                    # point in this sweep (but not in the first one); kept as
                    # in the original -- possibly redundant.
                    input_1, _ = _dendrite_sweep__build_input(drive, dt, tf)

                    sim_params = _dendrite_sweep__sim_params(
                        amp_vec[aa], mu1_best, mu2_best, mu3_vec[bb],
                        mu4_vec[cc], dt, tf)
                    actual_data = _dendrite_sweep__simulate(
                        input_1, L_di, tau_di, sim_params)
                    error_mat_2[aa, bb, cc] = chi_squared_error(
                        target_data, actual_data)

        error_mat_master__mu3_mu4 += error_mat_2
        aa_best, bb_best, cc_best = _dendrite_sweep__first_minimum(
            error_mat_2)
        amp_best_mu34 = amp_vec[aa_best]
        mu3_best = mu3_vec[bb_best]
        mu4_best = mu4_vec[cc_best]
        print('\n\namp_best_mu34 = {}'.format(amp_best_mu34))
        print('mu3_best = {}'.format(mu3_best))
        print('mu4_best = {}'.format(mu4_best))
        best_params['amp_mu34'].append(amp_best_mu34)
        best_params['mu3'].append(mu3_best)
        best_params['mu4'].append(mu4_best)
        # Overwritten on every file: only the last file's matrix is saved.
        data_array['error_mat__mu3_mu4'] = error_mat_2

        #plot errors
        title_string = '{}\namp_best_mu12 = {:2.2f}, amp_best_mu34 = {:2.2f}, mu1_best = {:1.2f}, mu2_best = {:1.2f}, mu3_best = {:1.2f}, mu4_best = {:1.2f}'.format(
            data_file, amp_best_mu12, amp_best_mu34, mu1_best, mu2_best,
            mu3_best, mu4_best)
        save_str = '{}__error__mu3_mu4'.format(data_file)
        for aa in range(num_amps):
            plot_error_mat(error_mat_2[aa, :, :], mu3_vec, mu4_vec, 'mu3',
                           'mu4', 'amp = {}'.format(amp_vec[aa]), title_string,
                           save_str)

        #repeat best one and plot
        input_1, _ = _dendrite_sweep__build_input(drive, dt, tf)
        sim_params = _dendrite_sweep__sim_params(amp_best_mu34, mu1_best,
                                                 mu2_best, mu3_best, mu4_best,
                                                 dt, tf)
        actual_data = _dendrite_sweep__simulate(input_1, L_di, tau_di,
                                                sim_params)
        main_title = '{}\namp_best_mu12_{:2.2f}_amp_best_mu34_{:2.2f}_mu1_best_{:1.2f}_mu2_best_{:1.2f}_mu3_best_{:1.2f}_mu4_best_{:1.2f}'.format(
            data_file, amp_best_mu12, amp_best_mu34, mu1_best, mu2_best,
            mu3_best, mu4_best)
        error__signal = np.amin(error_mat_2)
        plot_wr_comparison__drive_and_response(main_title, target_data__drive,
                                               actual_data__drive, target_data,
                                               actual_data, data_file,
                                               error__drive, error__signal)

    # save data
    save_string = 'wr_fits__finding_amp_mu1_mu2+mu3_mu4'
    data_array['wr_spice_data_file_list'] = data_file_list
    data_array['best_params'] = best_params
    data_array['error_mat_master__mu1_mu2'] = error_mat_master__mu1_mu2
    data_array['error_mat_master__mu3_mu4'] = error_mat_master__mu3_mu4
    print('\n\nsaving session data ...')
    save_session_data(data_array, save_string)

    #plot errors
    # The master title reports the minimum of the summed matrices.  The
    # per-file bests left over from the last loop iteration described only
    # that one file (and were undefined for an empty data_file_list).
    aa_best, bb_best, cc_best = _dendrite_sweep__first_minimum(
        error_mat_master__mu1_mu2)
    master_amp_mu12 = amp_vec[aa_best]
    master_mu1 = mu1_vec[bb_best]
    master_mu2 = mu2_vec[cc_best]
    aa_best, bb_best, cc_best = _dendrite_sweep__first_minimum(
        error_mat_master__mu3_mu4)
    master_amp_mu34 = amp_vec[aa_best]
    master_mu3 = mu3_vec[bb_best]
    master_mu4 = mu4_vec[cc_best]
    title_string = '{}\namp_best_mu12 = {:2.2f}, amp_best_mu34 = {:2.2f}, mu1_best = {:1.2f}, mu2_best = {:1.2f}, mu3_best = {:1.2f}, mu4_best = {:1.2f}'.format(
        master_error_plot_name, master_amp_mu12, master_amp_mu34, master_mu1,
        master_mu2, master_mu3, master_mu4)
    for aa in range(num_amps):
        amp_label = 'amp = {}'.format(amp_vec[aa])
        save_str_1 = '{}__master_error__mu1_mu2__amp_{:2.2f}'.format(
            master_error_plot_name, amp_vec[aa])
        save_str_2 = '{}__master_error__mu3_mu4__amp_{:2.2f}'.format(
            master_error_plot_name, amp_vec[aa])
        plot_error_mat(error_mat_master__mu1_mu2[aa, :, :], mu1_vec, mu2_vec,
                       'mu1', 'mu2', amp_label, title_string, save_str_1)
        plot_error_mat(error_mat_master__mu3_mu4[aa, :, :], mu3_vec, mu4_vec,
                       'mu3', 'mu4', amp_label, title_string, save_str_2)

    return best_params, error_mat_master__mu1_mu2, error_mat_master__mu3_mu4


def _dendrite_sweep__drive_settings(drive_info):
    """Translate ``drive_info`` into the per-drive-type settings of the sweep.

    Returns a dict with the keys
      'family'       -- 'pwl', 'sq' or 'exp'; selects the drive generator
      'params'       -- the drive parameters read from ``drive_info``
      'signal_kwarg' -- keyword under which 'params' goes to input_signal
      'directory'    -- sub-directory of wrspice_data/ with the .dat files
      'wr_drive'     -- key of the drive trace in the WR data dict
      'wr_target'    -- key of the target trace in the WR data dict

    Raises ValueError for an unrecognized ``drive_info['drive_type']``.
    """
    drive_type = drive_info['drive_type']
    if drive_type == 'piecewise_linear' or drive_type == 'linear_ramp':
        if drive_type == 'piecewise_linear':
            directory = 'constant_drive'
        else:
            directory = 'linear_ramp'
        return {
            'family': 'pwl',
            'params': drive_info['pwl_drive'],
            'signal_kwarg': 'piecewise_linear',
            'directory': directory,
            'wr_drive': '@I0[c]',
            'wr_target': 'L9#branch',
        }
    if drive_type == 'sq_pls_trn':
        return {
            'family': 'sq',
            'params': drive_info['sq_pls_trn_params'],
            'signal_kwarg': 'square_pulse_train',
            'directory': 'square_pulse_sequence',
            'wr_drive': '@I0[c]',
            'wr_target': 'L9#branch',
        }
    if drive_type == 'exp_pls_trn':
        return {
            'family': 'exp',
            'params': drive_info['exp_pls_trn_params'],
            'signal_kwarg': 'exponential_pulse_train',
            'directory': 'exponential_pulse_sequence',
            'wr_drive': 'L5#branch',
            'wr_target': 'L10#branch',
        }
    raise ValueError(
        "unrecognized drive_type {!r}; expected 'piecewise_linear', "
        "'linear_ramp', 'sq_pls_trn' or 'exp_pls_trn'".format(drive_type))


def _dendrite_sweep__build_input(drive, dt, tf):
    """Create the 'input_dendritic_drive' signal and evaluate its drive.

    ``drive`` is the dict returned by _dendrite_sweep__drive_settings.
    Returns ``(input_1, dendritic_drive)``.
    """
    params = drive['params']
    input_1 = input_signal('input_dendritic_drive',
                           input_temporal_form='analog_dendritic_drive',
                           output_inductance=200e-12,
                           time_vec=np.arange(0, tf + dt, dt),
                           **{drive['signal_kwarg']: params})
    if drive['family'] == 'pwl':
        dendritic_drive = dendritic_drive__piecewise_linear(
            input_1.time_vec, params)
    elif drive['family'] == 'sq':
        dendritic_drive = dendritic_drive__square_pulse_train(
            input_1.time_vec, params)
    else:
        dendritic_drive = dendritic_drive__exp_pls_train__LR(
            input_1.time_vec, params)
    return input_1, dendritic_drive


def _dendrite_sweep__sim_params(amp, mu1, mu2, mu3, mu4, dt, tf):
    """Bundle one parameter combination into a dendrite_model_params dict."""
    return {
        'amp': amp,
        'mu1': mu1,
        'mu2': mu2,
        'mu3': mu3,
        'mu4': mu4,
        'dt': dt,
        'tf': tf,
    }


def _dendrite_sweep__simulate(input_1, L_di, tau_di, sim_params):
    """Run one dendrite simulation and return its response.

    The result is a 2-row array: ``input_1.time_vec`` stacked on the first
    column of the simulated ``I_di``.
    """
    # The dendrite names its direct input 'input_dendritic_drive', the name
    # given to the input signal built by _dendrite_sweep__build_input.
    dendrite_1 = dendrite(
        'dendrite_under_test',
        inhibitory_or_excitatory='excitatory',
        circuit_inductances=[10e-12, 26e-12, 200e-12, 77.5e-12],
        input_synaptic_connections=[],
        input_synaptic_inductances=[[]],
        input_dendritic_connections=[],
        input_dendritic_inductances=[[]],
        input_direct_connections=['input_dendritic_drive'],
        input_direct_inductances=[[10e-12, 1]],
        thresholding_junction_critical_current=40e-6,
        bias_currents=[72e-6, 29e-6, 35e-6],
        integration_loop_self_inductance=L_di,
        integration_loop_output_inductance=0e-12,
        integration_loop_temporal_form='exponential',
        integration_loop_time_constant=tau_di,
        integration_loop_saturation_current=11.75e-6,
        dendrite_model_params=sim_params)
    dendrite_1.run_sim()
    return np.vstack((input_1.time_vec[:], dendrite_1.I_di[:, 0]))


def _dendrite_sweep__first_minimum(error_mat):
    """Return the index tuple of the first entry equal to the minimum.

    Plain ints are returned so the indices work on lists as well as on
    numpy arrays.
    """
    ind_best = np.where(error_mat == np.amin(error_mat))
    return tuple(int(hits[0]) for hits in ind_best)
#%% assemble and plot error_mat_masters
# Element-wise sum of the master error matrices of the three studies
# (no-leak/vary L_di, vary tau_di, vary I_drive).
error_mat_master__mu1_mu2 = (error_mat_master__mu1_mu2__no_leak__vary_Ldi
                             + error_mat_master__mu1_mu2__vary_taudi
                             + error_mat_master__mu1_mu2__vary_Idrive)
error_mat_master__mu3_mu4 = (error_mat_master__mu3_mu4__no_leak__vary_Ldi
                             + error_mat_master__mu3_mu4__vary_taudi
                             + error_mat_master__mu3_mu4__vary_Idrive)

# Best (amp, mu1, mu2): every position equal to the minimum is located with
# np.where and the first match is kept.
ind_best = np.where(
    error_mat_master__mu1_mu2 == np.amin(error_mat_master__mu1_mu2))
amp_best_mu12, mu1_best, mu2_best = (amp_vec[ind_best[0]][0],
                                     mu1_vec[ind_best[1]][0],
                                     mu2_vec[ind_best[2]][0])

# Best (amp, mu3, mu4), found the same way on the second matrix.
ind_best = np.where(
    error_mat_master__mu3_mu4 == np.amin(error_mat_master__mu3_mu4))
amp_best_mu34, mu3_best, mu4_best = (amp_vec[ind_best[0]][0],
                                     mu3_vec[ind_best[1]][0],
                                     mu4_vec[ind_best[2]][0])

print('\n\namp_best_mu12 = {}'.format(amp_best_mu12))
print('mu1_best = {}'.format(mu1_best))
print('mu2_best = {}\n\n'.format(mu2_best))
print('\n\namp_best_mu34 = {}'.format(amp_best_mu34))
print('mu3_best = {}'.format(mu3_best))
print('mu4_best = {}\n\n'.format(mu4_best))

# One pair of error maps per amplitude, all sharing the same title.
pre_str = 'dend_fit_master'
title_string = '{}; amp_best_mu12 = {:2.2f}, amp_best_mu34 = {:2.2f}, mu1_best = {:1.2f}, mu2_best = {:1.2f}, mu3_best = {:1.2f}, mu4_best = {:1.2f}'.format(
    pre_str, amp_best_mu12, amp_best_mu34, mu1_best, mu2_best, mu3_best,
    mu4_best)
for aa in range(len(amp_vec)):
    save_str_1 = '{}__master_error__mu1_mu2__amp_{:2.2f}'.format(
        pre_str, amp_vec[aa])
    save_str_2 = '{}__master_error__mu3_mu4__amp_{:2.2f}'.format(
        pre_str, amp_vec[aa])
    plot_error_mat(error_mat_master__mu1_mu2[aa, :, :], mu1_vec, mu2_vec,
                   'mu1', 'mu2', 'amp = {}'.format(amp_vec[aa]), title_string,
                   save_str_1)
    plot_error_mat(error_mat_master__mu3_mu4[aa, :, :], mu3_vec, mu4_vec,
                   'mu3', 'mu4', 'amp = {}'.format(amp_vec[aa]), title_string,
                   save_str_2)

#%% load data
# load_string = 'session_data__wr_fits__no_leak__vary_L_di__finding_amp_mu1_mu2+mu3_mu4__2020-04-08_14-29-15.dat'
# data_array_imported = load_session_data(load_string)
Exemplo n.º 3
0
                             error_mat_master__mu1_mu2__exp_pls_trn_per100ns)

# error_mat_master__mu1_mu2 = ( error_mat_master__mu1_mu2__no_leak__vary_Ldi
#                              +error_mat_master__mu1_mu2__vary_taudi
#                              +error_mat_master__mu1_mu2__vary_Idrive_20uA
#                              +error_mat_master__mu1_mu2__vary_Idrive_25uA
#                              +error_mat_master__mu1_mu2__vary_Idrive_30uA
#                              +error_mat_master__mu1_mu2__sq_pls_trn_per20ns
#                              +error_mat_master__mu1_mu2__sq_pls_trn_per100ns
#                              +error_mat_master__mu1_mu2__exp_pls_trn_per100ns
#                              +error_mat_master__mu1_mu2__exp_pls_trn_per1000ns )

# Locate the minimum of the 2-D master error matrix; np.where lists every
# position equal to the minimum and the first match is kept.
ind_best = np.where(
    error_mat_master__mu1_mu2 == np.amin(error_mat_master__mu1_mu2))
mu1_best, mu2_best = mu1_vec[ind_best[0]][0], mu2_vec[ind_best[1]][0]
print('mu1_best = {}'.format(mu1_best))
print('mu2_best = {}\n\n'.format(mu2_best))

# Single error map over the (mu1, mu2) grid, titled with the best pair.
pre_str = 'dend_fit_master'
title_string = '{}; mu1_best = {:1.2f}, mu2_best = {:1.2f}'.format(
    pre_str,
    mu1_best,
    mu2_best,
)
save_str_1 = '{}__master_error__mu1_mu2'.format(pre_str)
plot_error_mat(
    error_mat_master__mu1_mu2[:, :],
    mu1_vec,
    mu2_vec,
    'mu1',
    'mu2',
    title_string,
    save_str_1,
)

#%% load data
# load_string = 'session_data__wr_fits__no_leak__vary_L_di__finding_amp_mu1_mu2+mu3_mu4__2020-04-08_14-29-15.dat'
# data_array_imported = load_session_data(load_string)
Ljj(40e-6, 35e-6)
Exemplo n.º 4
0
def synapse_model__parameter_sweep(data_file_list, I_sy_vec, L_si_vec,
                                   tau_si_vec, dt, tf, spike_times, gamma1_vec,
                                   gamma2_vec, gamma3_vec,
                                   master_error_plot_name):

    gamma3_init = 1

    master_error_plot_name = 'mstr_err__' + master_error_plot_name

    best_params = dict()
    best_params['gamma1'] = []
    best_params['gamma2'] = []
    best_params['gamma3'] = []

    data_array = dict()
    data_array['gamma1_vec'] = gamma1_vec
    data_array['gamma2_vec'] = gamma2_vec
    data_array['gamma3_vec'] = gamma3_vec

    num_sims = len(data_file_list)
    num_gamma1 = len(gamma1_vec)
    num_gamma2 = len(gamma2_vec)
    num_gamma3 = len(gamma3_vec)

    print(
        '\n\nrunning synapse_model__parameter_sweep\n\nnum_files = {:d}\nnum_gamma1 = {:d}\nnum_gamma2 = {:d}'
        .format(num_sims, num_gamma1, num_gamma2))

    #------------------------
    # find best gamma1, gamma2
    #------------------------

    error_mat_master__gamma1_gamma2 = np.zeros([num_gamma1, num_gamma2])
    for ii in range(num_sims):
        print('\ndata_file {} of {}\n'.format(ii + 1, num_sims))
        # plt.close('all')

        # WR data
        directory = 'wrspice_data/fitting_data'
        file_name = data_file_list[ii]
        data_dict = read_wr_data(directory + '/' + file_name)
        target_data = np.vstack((data_dict['time'], data_dict['L3#branch']))
        wr_drive = np.vstack((data_dict['time'], data_dict['L0#branch']))

        # initialize input signal
        name__i = 'in'
        input_1 = input_signal(name__i,
                               input_temporal_form='arbitrary_spike_train',
                               spike_times=spike_times)

        error_mat_1 = np.zeros([num_gamma1, num_gamma2])
        print('\nseeking gamma1, gamma2 ...\n')
        for aa in range(num_gamma1):
            for bb in range(num_gamma2):

                print('aa = {} of {}, bb = {} of {}'.format(
                    aa + 1, num_gamma1, bb + 1, num_gamma2))

                # create sim_params dictionary
                sim_params = dict()
                sim_params['gamma1'] = gamma1_vec[aa]
                sim_params['gamma2'] = gamma2_vec[bb]
                sim_params['gamma3'] = gamma3_init
                sim_params['dt'] = dt
                sim_params['tf'] = tf

                # initialize synapse
                name_s = 'sy'
                synapse_1 = synapse(
                    name_s,
                    integration_loop_temporal_form='exponential',
                    integration_loop_time_constant=tau_si_vec[ii],
                    integration_loop_self_inductance=L_si_vec[ii],
                    integration_loop_output_inductance=0e-12,
                    synaptic_bias_current=I_sy_vec[ii],
                    integration_loop_bias_current=35e-6,
                    input_signal_name='in',
                    synapse_model_params=sim_params)

                synapse_1.run_sim()

                actual_data = np.vstack(
                    (synapse_1.time_vec[:], synapse_1.I_si[:, 0]))
                error_mat_1[aa,
                            bb] = chi_squared_error(target_data, actual_data)

                plot_wr_comparison__synapse(
                    file_name + '; gamma1 = {:f}, gamma2 = {:f}'.format(
                        gamma1_vec[aa], gamma2_vec[bb]), spike_times, wr_drive,
                    target_data, actual_data, file_name, error_mat_1[aa, bb])

        error_mat_master__gamma1_gamma2 += error_mat_1
        ind_best = np.where(
            error_mat_1 == np.amin(error_mat_1))  #error_mat.argmin()
        gamma1_best = gamma1_vec[ind_best[0]][0]
        gamma2_best = gamma2_vec[ind_best[1]][0]
        print('gamma1_best = {}'.format(gamma1_best))
        print('gamma2_best = {}\n\n'.format(gamma2_best))
        best_params['gamma1'].append(gamma1_best)
        best_params['gamma2'].append(gamma2_best)
        data_array['error_mat__gamma1_gamma2'] = error_mat_1

        #plot errors
        # title_string = '{}\ngamma1_best = {:1.2f}, gamma2_best = {:1.2f}'.format(data_file_list[ii],gamma1_best,gamma2_best)
        # save_str = '{}__error__gamma1_gamma2'.format(data_file_list[ii])
        # plot_error_mat(error_mat_1[:,:],gamma1_vec,gamma2_vec,'gamma1','gamma2',title_string,save_str)

        # #repeat best one and plot
        sim_params = dict()
        sim_params['gamma1'] = gamma1_best
        sim_params['gamma2'] = gamma2_best
        sim_params['gamma3'] = gamma3_init
        sim_params['dt'] = dt
        sim_params['tf'] = tf

        # initialize synapse
        name_s = 'sy'
        synapse_1 = synapse(name_s,
                            integration_loop_temporal_form='exponential',
                            integration_loop_time_constant=tau_si_vec[ii],
                            integration_loop_self_inductance=L_si_vec[ii],
                            integration_loop_output_inductance=0e-12,
                            synaptic_bias_current=I_sy_vec[ii],
                            integration_loop_bias_current=35e-6,
                            input_signal_name='in',
                            synapse_model_params=sim_params)

        synapse_1.run_sim()

        actual_data = np.vstack((synapse_1.time_vec[:], synapse_1.I_si[:, 0]))
        error__si = chi_squared_error(target_data, actual_data)

        main_title = '{}\ngamma1_best_{:6.4f}_gamma2_best_{:6.4f}'.format(
            data_file_list[ii], gamma1_best, gamma2_best)
        plot_wr_comparison__synapse(main_title, spike_times, wr_drive,
                                    target_data, actual_data, file_name,
                                    error__si)

    # # save data
    # save_string = 'wr_fits__finding_gamma1_gamma2'
    # data_array['wr_spice_data_file_list'] = data_file_list
    # data_array['best_params'] = best_params
    # data_array['error_mat_master__gamma1_gamma2'] = error_mat_master__gamma1_gamma2
    # print('\n\nsaving session data ...')
    # save_session_data(data_array,save_string)

    #plot errors
    title_string = '{}; gamma1_best = {:6.4f}, gamma2_best = {:6.4f}'.format(
        master_error_plot_name, gamma1_best, gamma2_best)
    save_str_1 = '{}__master_error__gamma1_gamma2'.format(
        master_error_plot_name)
    plot_error_mat(error_mat_master__gamma1_gamma2[:, :], gamma1_vec,
                   gamma2_vec, 'gamma1', 'gamma2', title_string, save_str_1)

    #-----------------
    # find best gamma3
    #-----------------
    # error_mat_master__gamma3 = np.zeros([num_gamma3])
    # for ii in range(num_sims):
    #     print('\ndata_file {} of {}\n'.format(ii+1,num_sims))
    #     # plt.close('all')

    #     # WR data
    #     directory = 'wrspice_data/fitting_data'
    #     file_name = data_file_list[ii]
    #     data_dict = read_wr_data(directory+'/'+file_name)
    #     target_data = np.vstack((data_dict['time'],data_dict['L3#branch']))
    #     wr_drive = np.vstack((data_dict['time'],data_dict['L0#branch']))

    #     # initialize input signal
    #     name__i = 'in'
    #     input_1 = input_signal(name__i, input_temporal_form = 'arbitrary_spike_train', spike_times = spike_times)

    #     error_mat_2 = np.zeros([num_gamma3])
    #     print('\nseeking gamma3 ...\n')
    #     for aa in range(num_gamma3):

    #         print('aa = {} of {}'.format(aa+1,num_gamma3))

    #         # create sim_params dictionary
    #         sim_params = dict()
    #         sim_params['gamma1'] = gamma1_best
    #         sim_params['gamma2'] = gamma2_best
    #         sim_params['gamma3'] = gamma3_vec[aa]
    #         sim_params['dt'] = dt
    #         sim_params['tf'] = tf

    #         # initialize synapse
    #         name_s = 'sy'
    #         synapse_1 = synapse(name_s, integration_loop_temporal_form = 'exponential', integration_loop_time_constant = tau_si_vec[ii],
    #                             integration_loop_self_inductance = L_si_vec[ii], integration_loop_output_inductance = 0e-12,
    #                             synaptic_bias_current = I_sy_vec[ii], integration_loop_bias_current = 35e-6,
    #                             input_signal_name = 'in', synapse_model_params = sim_params)

    #         synapse_1.run_sim()

    #         actual_data = np.vstack((synapse_1.time_vec[:],synapse_1.I_si[:,0]))
    #         error_mat_2[aa] = chi_squared_error(target_data,actual_data)

    #         # plot_wr_comparison__synapse(file_name+'gamma1 = {:f}, gamma2 = {:f}'.format(gamma1_vec[aa],gamma2_vec[bb]),spike_times,wr_drive,target_data,actual_data,file_name,error_mat_1[aa,bb])

    #     error_mat_master__gamma3 += error_mat_2
    #     ind_best = np.where(error_mat_2 == np.amin(error_mat_2))#error_mat.argmin()
    #     gamma3_best = gamma3_vec[ind_best[0]][0]
    #     print('gamma3_best = {}'.format(gamma3_best))
    #     best_params['gamma3'].append(gamma3_best)
    #     data_array['error_mat__gamma3'] = error_mat_2

    #     #plot errors
    #     # title_string = '{}\ngamma1_best = {:1.2f}, gamma2_best = {:1.2f}'.format(data_file_list[ii],gamma1_best,gamma2_best)
    #     # save_str = '{}__error__gamma1_gamma2'.format(data_file_list[ii])
    #     # plot_error_mat(error_mat_1[:,:],gamma1_vec,gamma2_vec,'gamma1','gamma2',title_string,save_str)

    #     # repeat best one and plot
    #     sim_params = dict()
    #     sim_params['gamma1'] = gamma1_best
    #     sim_params['gamma2'] = gamma2_best
    #     sim_params['gamma3'] = gamma3_best
    #     sim_params['dt'] = dt
    #     sim_params['tf'] = tf

    #     # initialize synapse
    #     name_s = 'sy'
    #     synapse_1 = synapse(name_s, integration_loop_temporal_form = 'exponential', integration_loop_time_constant = tau_si_vec[ii],
    #                         integration_loop_self_inductance = L_si_vec[ii], integration_loop_output_inductance = 0e-12,
    #                         synaptic_bias_current = I_sy_vec[ii], integration_loop_bias_current = 35e-6,
    #                         input_signal_name = 'in', synapse_model_params = sim_params)

    #     synapse_1.run_sim()

    #     actual_data = np.vstack((synapse_1.time_vec[:],synapse_1.I_si[:,0]))
    #     error__si = chi_squared_error(target_data,actual_data)

    #     main_title = '{}\ngamma1_best = {:6.4f}, gamma2_best = {:6.4f}\ngamma3_best_{:6.4f}'.format(data_file_list[ii],gamma1_best,gamma2_best,gamma3_best)
    #     plot_wr_comparison__synapse(main_title,spike_times,wr_drive,target_data,actual_data,file_name,error__si)

    # # save data
    # save_string = 'wr_fits__finding_gamma1_gamma2'
    # data_array['wr_spice_data_file_list'] = data_file_list
    # data_array['best_params'] = best_params
    # data_array['error_mat_master__gamma1_gamma2'] = error_mat_master__gamma1_gamma2
    # print('\n\nsaving session data ...')
    # save_session_data(data_array,save_string)

    #plot errors
    # title_string = '{}; gamma3_best = {:6.4f}'.format(master_error_plot_name,gamma3_best)
    # save_str_1 = '{}__master_error__gamma3'.format(master_error_plot_name)
    # plot_error_vec(error_mat_master__gamma1_gamma2[:,:],gamma1_vec,gamma2_vec,'gamma1','gamma2',title_string,save_str_1)

    return best_params, error_mat_master__gamma1_gamma2, error_mat_master__gamma3
Exemplo n.º 5
0
                                   error_mat_master__gamma1_gamma2__vary_Lsi +
                                   error_mat_master__gamma1_gamma2__vary_tausi)

# Master gamma3 error curve: the three `__vary_*` gamma3 error arrays are
# added together, accumulated left to right.
error_mat_master__gamma3 = (error_mat_master__gamma3__vary_Isy +
                            error_mat_master__gamma3__vary_Lsi)
error_mat_master__gamma3 = (error_mat_master__gamma3 +
                            error_mat_master__gamma3__vary_tausi)

# Best (gamma1, gamma2): collect the indices of every cell equal to the
# minimum of the master error matrix (a tuple of per-axis index arrays),
# then keep the first hit along each axis.
ind_best = np.nonzero(
    error_mat_master__gamma1_gamma2 == np.min(error_mat_master__gamma1_gamma2))
gamma1_best = gamma1_vec[ind_best[0]][0]
gamma2_best = gamma2_vec[ind_best[1]][0]

# Best gamma3: position of the smallest entry of the summed error curve.
# Note that `ind_best` is rebound here to that single index.
ind_best = error_mat_master__gamma3.argmin()
gamma3_best = gamma3_vec[ind_best]

# Shared figure title reporting all three best-fit values, plus the base
# names used for the two saved figures.
title_string = (
    'Master Error; gamma1_best = {:7.4f}, gamma2_best = {:7.4f}, gamma3_best = {:7.4f}'
    .format(gamma1_best, gamma2_best, gamma3_best))
save_str_1, save_str_2 = ('sy__master_error__gamma1_gamma2',
                          'sy__master_error__gamma3')

# 2-D error map over (gamma1, gamma2), drawn by the project helper
# (presumably it also saves under `save_str_1` -- confirm in plot_error_mat).
plot_error_mat(error_mat_master__gamma1_gamma2[:, :], gamma1_vec, gamma2_vec,
               'gamma1', 'gamma2', title_string, save_str_1)

# 1-D error curve versus gamma3 on a single-axes figure.
fig, ax = plt.subplots()
fig.suptitle(title_string)
ax.plot(gamma3_vec, error_mat_master__gamma3)
ax.set(xlabel=r'$\gamma_3$', ylabel='Error')
plt.show()
# Saving goes through the figure handle, so it does not depend on which
# figure pyplot considers current after show() returns.
fig.savefig(''.join(('figures/', save_str_2, '.png')))