def compare_correlated_summary_stats(params, syn_current, pair, corr_mat,
                                     mag_of_change, selected_indices,
                                     using="C++", mode="auto",
                                     summary_func=None):
    """Simulate a reference and a correlated trace and summarise both.

    The correlated parameter set is derived from ``params`` with
    ``generate_correlated_parameters``; both parameter sets are simulated
    with ``runHH`` and the selected summary statistics of the two resulting
    traces are returned side by side.

    Parameters
    ----------
    params : torch.Tensor
        A 1D parameter tensor that holds the values of input parameters to
        the simulator model.
    syn_current : dict
        Contains the initial current "V0" [float], the stimulation current
        trace "It" [ndarray], the product of the stimulation current trace
        and the area of the soma "ItA": ndarray, the time axis "ts": ndarray,
        the step size "dt": float and the stimulus time course
        "StimOnset": float, "StimEnd": float and "StimDur": float.
    pair : tuple, list, numpy.ndarray
        Specifies which correlation axis to change the parameters along.
    corr_mat : torch.Tensor
        Matrix containing the pairwise correlation coefficients of the
        different parameters.
    mag_of_change : float
        Value from the interval [-k, k], with k in R, that specifies by how
        much the input parameters are altered.
    selected_indices : None, tuple, list or array
        Indices of summary stats to be used.
    using : str
        Forwarded unchanged to ``runHH``.
    mode : str
        Specifies whether cell/experiment specific parameters like ASoma and
        V0 should be specified as part of the model parameters
        ("model_params") or if ASoma and V0 should be considered part of the
        syn_current or trace object ("trace").
    summary_func : function or str
        A function that takes voltage trace V(t), current trace I(t), time
        points t and time step dt and outputs a dictionary with summary
        stats and descriptions. Alternatively a descriptive string like "v1"
        or "v2" can be provided to use one of the summary methods already
        implemented as part of Trace().

    Returns
    -------
    sum_df : pandas.DataFrame
        Contains the summary statistics for both the base trace (column
        "base params") and the correlated trace (column "correlated params").
    """
    # generate correlated parameters
    correlated_params = generate_correlated_parameters(
        params, corr_mat, pair=pair, mag_of_change=mag_of_change)

    # run simulation with correlated and reference params (same order as
    # before: correlated first, then base)
    corr_trace = runHH(correlated_params, syn_current=syn_current,
                       using=using, mode=mode)
    base_trace = runHH(params, syn_current=syn_current,
                       using=using, mode=mode)

    # collect results of summaries in one DataFrame, one column per trace
    corr_summary = pd.Series(
        corr_trace.summarise(selected_indices, summary_func=summary_func),
        name="correlated params")
    base_summary = pd.Series(
        base_trace.summarise(selected_indices, summary_func=summary_func),
        name="base params")

    return pd.concat([base_summary, corr_summary], axis=1)
def plot_best_matches(thetas=None, stats_sim=None, stats_obs=None,
                      simulated_traces=None, trace_obs=None, samplesize=10,
                      figsize=(10, 14), savefig=False, fig_name="default",
                      selected_stats=None, metric="MSE", mode="auto"):
    """Plot the best matching voltage traces against the observed trace.

    The ranking of the simulations is obtained from ``best_matches`` based on
    the summary statistics. If no simulated traces are supplied they are
    simulated from ``thetas`` with ``runHH``.

    Parameters
    ----------
    thetas : torch.tensor, numpy.ndarray or pandas.DataFrame
        Holds the parameters that go along with stats_sim. Only used when
        ``simulated_traces`` is None.
    stats_sim : pandas.DataFrame
        Contains the summarised simulation results.
    stats_obs : pandas.DataFrame
        Contains the summarised observation results.
    simulated_traces : list[Trace]
        The simulated data of the voltage traces in the form of Trace()
        objects.
    trace_obs : Trace
        Trace object for the observed voltage trace. Required.
    samplesize : int
        How many simulated traces to compare.
    figsize : tuple, list or ndarray
        Changes the figure size of the plot.
    savefig : bool
        Determines whether or not to save the plotted figure.
    fig_name : str
        Integrates figure name into the filename of the saved figure.
    selected_stats : None, tuple, list or array
        Indices of summary stats to be used.
    metric : str
        Determines which metric to sort the samples by. Currently "MSE" and
        "std" are supported.
    mode : str
        Specifies whether cell/experiment specific parameters like ASoma and
        V0 should be specified as part of the model parameters
        ("model_params") or if ASoma and V0 should be considered part of the
        syn_current or trace object ("trace").

    Raises
    ------
    ValueError
        If ``trace_obs`` is None.
    """
    # Fail fast: the observed trace is needed for matching, simulating and
    # plotting, so check it before any of that work is started.
    if trace_obs is None:
        raise ValueError("No observed trace was provided. Please provide Trace().")

    mins_idx, _ = best_matches(stats_sim, stats_obs, simulated_traces,
                               trace_obs, selected_stats, metric)

    if simulated_traces is None:
        # NOTE(review): only DataFrame thetas are narrowed down to the best
        # `samplesize` matches; any other type is handed to runHH unchanged.
        # Confirm that this is intended.
        if isinstance(thetas, pd.DataFrame):
            thetas = torch.tensor(thetas.values[mins_idx[:samplesize]])
        simulated_traces = runHH(thetas, trace_obs, mode=mode)

    # honour the caller's figsize (it used to be hard-coded to (10, 14))
    plot_comparison(simulated_traces, trace_obs, samplesize, figsize=figsize)

    if savefig:
        plt.savefig('../data/figures/{}_8paramprior_support_best_sims_v_obs.png'.format(fig_name))
def show_correlated_traces(params, syn_current, corr_mat, pair=(0,1),
                           start=0.01, stop=0.5, N=5, figsize=(15,10),
                           timewindow=[0.3,0.33], compare_changes=False,
                           savefig=False, fig_name="default", mode="auto"):
    """Plot the effect of a correlated parameter change on the voltage trace.

    For one parameter pair, the parameters are changed along the correlation
    axis for ``N`` magnitudes between ``start`` and ``stop``. Each resulting
    trace is plotted together with the trace of the unchanged parameters.

    Parameters
    ----------
    params : torch.Tensor
        A 1D parameter tensor that holds the values of input parameters to
        the simulator model.
    syn_current : dict
        Contains the initial current "V0" [float], the stimulation current
        trace "It" [ndarray], the product of the stimulation current trace
        and the area of the soma "ItA": ndarray, the time axis "ts": ndarray,
        the step size "dt": float and the stimulus time course
        "StimOnset": float, "StimEnd": float and "StimDur": float.
    corr_mat : torch.Tensor
        Matrix containing the pairwise correlation coefficients of the
        different parameters.
    pair : tuple, list, numpy.ndarray
        Specifies which correlation axis to change the parameters along.
    start : float
        Minimum value of parameter change along correlation axis.
    stop : float
        Maximum value of parameter change along correlation axis.
    N : int
        Number of values in between min. and max. amount of change.
    figsize : tuple
        Size passed to ``plt.figure``.
    timewindow : tuple, list or ndarray
        The voltage and current trace will only be plotted between (t1,t2).
        To be specified in secs.
    compare_changes : bool
        Whether or not to include a plot that compares the correlated traces
        for all magnitudes of change.
    savefig : bool
        Determines whether or not to save the plotted figure.
    fig_name : str
        Integrates figure name into the filename of the saved figure.
    mode : str
        Specifies whether cell/experiment specific parameters like ASoma and
        V0 should be specified as part of the model parameters
        ("model_params") or if ASoma and V0 should be considered part of the
        syn_current or trace object ("trace").
    """
    base_trace = runHH(params, syn_current=syn_current, mode=mode)

    # y-limits only depend on the base trace, so compute them once
    base_mv = base_trace.Vt*1e3
    y_low, y_high = min(base_mv), max(base_mv)

    corr_traces = []
    plt.figure(figsize=figsize)
    for idx, change in enumerate(np.linspace(start, stop, N)):
        # generate correlated params, simulate and store trace; the same
        # `mode` as for the base trace is used so both are comparable
        pars = generate_correlated_parameters(params, corr_mat, pair=pair,
                                              mag_of_change=change)
        corr_trace = runHH(pars, syn_current=syn_current, mode=mode)
        corr_traces.append((change, corr_trace))

        # plot both traces onto same axis
        axs = base_trace.inspect(voltage_only=True, timewindow=timewindow,
                                 label="base params", line_color="grey")
        corr_trace.inspect(axes=axs, voltage_only=True, timewindow=timewindow,
                           label="correlated params",
                           line_color=colormap(idx/N))
        plt.ylim(y_low, y_high)
        plt.legend(loc=1)
        plt.title("change along correlation axis = {0:.1f}%".format(change*100))

        if savefig:
            # NOTE(review): the file name does not depend on `change`, so
            # every iteration overwrites the file of the previous one.
            plt.savefig('../data/figures/{}_8paramprior_support_corr_change_effect_on_traces_{}.png'.format(fig_name, pair))

    # plots all changes into one set of axes
    if compare_changes:
        axs = base_trace.inspect(voltage_only=True, timewindow=timewindow,
                                 label="base params", line_color="grey")
        for idx, (change, corr_trace) in enumerate(corr_traces):
            corr_trace.inspect(axes=axs, voltage_only=True,
                               timewindow=timewindow,
                               label="{0:.1f}%".format(change*100),
                               line_color=colormap(idx/N))
        plt.legend(loc=1)

        if savefig:
            plt.savefig('../data/figures/{}_8paramprior_support_corr_change_effect_on_traces_combined_{}.png'.format(fig_name, pair))
def plot_correlated_summary_stats(params, syn_current, corr_mat, pair=(0,1),
                                  start=0.01, stop=0.5, N=5, figsize=(15,10),
                                  timewindow=[0.3,0.33], selected_stats=None,
                                  change_only=False, savefig=False,
                                  fig_name="default", mode="auto",
                                  summary_func="v1"):
    """Plot the effect of a correlated parameter change on the summary stats.

    For one parameter pair, the parameters are changed along the correlation
    axis for ``N`` magnitudes between ``start`` and ``stop``. For every
    magnitude the relative change (in percent) of each summary statistic
    with respect to the unchanged parameters is computed, and all of them
    are shown in one bar plot.

    Parameters
    ----------
    params : torch.Tensor
        A 1D parameter tensor that holds the values of input parameters to
        the simulator model.
    syn_current : dict
        Contains the initial current "V0" [float], the stimulation current
        trace "It" [ndarray], the product of the stimulation current trace
        and the area of the soma "ItA": ndarray, the time axis "ts": ndarray,
        the step size "dt": float and the stimulus time course
        "StimOnset": float, "StimEnd": float and "StimDur": float.
    corr_mat : torch.Tensor
        Matrix containing the pairwise correlation coefficients of the
        different parameters.
    pair : tuple, list, numpy.ndarray
        Specifies which correlation axis to change the parameters along.
    start : float
        Minimum value of parameter change along correlation axis.
    stop : float
        Maximum value of parameter change along correlation axis.
    N : int
        Number of values in between min. and max. amount of change.
    figsize : tuple
        Currently not used by this function; the bar plot is drawn with a
        fixed size of (15, 5).
    timewindow : tuple, list or ndarray
        Currently not used by this function.
    selected_stats : None, tuple, list or array
        Indices of summary stats to be used.
    change_only : bool
        Whether or not to only plot the change in the difference of summary
        stats.
    savefig : bool
        Determines whether or not to save the plotted figure.
    fig_name : str
        Integrates figure name into the filename of the saved figure.
    mode : str
        Specifies whether cell/experiment specific parameters like ASoma and
        V0 should be specified as part of the model parameters
        ("model_params") or if ASoma and V0 should be considered part of the
        syn_current or trace object ("trace").
    summary_func : function or str
        Summary method handed to ``Trace.summarise``, e.g. "v1" or "v2".
    """
    # The base summary only provides the column names for the (rare) case
    # that no change is evaluated at all. It uses the same summary function
    # as the comparison below so that the stat names agree.
    base_trace = runHH(params, syn_current=syn_current, mode=mode)
    base_trace.summarise(selected_stats, summary_func=summary_func)
    stat_names = list(base_trace.Summary.keys())

    # one single-row frame per magnitude of change
    change_rows = []
    for idx, change in enumerate(np.linspace(start, stop, N)):
        sum_df = compare_correlated_summary_stats(
            params, syn_current, pair, corr_mat, change, selected_stats,
            mode=mode, summary_func=summary_func)

        base_stats = sum_df["base params"]
        # relative deviation from the base summary in percent
        rel_changes = abs((base_stats - sum_df["correlated params"])/base_stats)*100
        changes_df = pd.DataFrame(
            rel_changes,
            columns=[r"$\Delta_{corr}$"+" = {}%".format(change*100)])
        change_rows.append(changes_df.T)

        if not change_only:
            plot_change_of_corr_summaries(sum_df, changes_df,
                                          colormap(idx/N), savefig, fig_name)

    # DataFrame.append was removed in pandas 2.0; concatenate once instead.
    if change_rows:
        change_hist = pd.concat(change_rows)
    else:
        change_hist = pd.DataFrame(data=None, columns=stat_names)

    change_hist.T.plot(kind="bar", figsize=(15,5), cmap=colormap)
    plt.ylim(0,50)

    if savefig:
        plt.savefig('../data/figures/{}_8paramprior_support_corr_change_effect_on_summaries_change.png'.format(fig_name))
def plot_correlation_effects(params, syn_current, corr_mat, mag_of_change=0.1,
                             figsize=(15,10), timewindow=[0.3,0.33],
                             effect_on="traces", selected_stats=None,
                             savefig=False, fig_name="default",
                             summary_func="v1", mode="auto"):
    """Plot the effect of correlated parameter changes for all pairs.

    For every parameter pair (i, j) with i < j, starting with (0,1), the
    parameters are changed along the correlation axis by ``mag_of_change``.
    Depending on ``effect_on`` either the resulting voltage trace or the
    percentage change in the summary statistics is drawn into the subplot
    at row i, column j-1 of an (N-1) x (N-1) grid, where N is the number of
    parameters.

    Parameters
    ----------
    params : torch.Tensor
        A 1D parameter tensor that holds the values of input parameters to
        the simulator model.
    syn_current : dict
        Contains the initial current "V0" [float], the stimulation current
        trace "It" [ndarray], the product of the stimulation current trace
        and the area of the soma "ItA": ndarray, the time axis "ts": ndarray,
        the step size "dt": float and the stimulus time course
        "StimOnset": float, "StimEnd": float and "StimDur": float.
    corr_mat : torch.Tensor
        Matrix containing the pairwise correlation coefficients of the
        different parameters.
    mag_of_change : float
        Value from the interval [-k, k], with k in R, that specifies by how
        much the input parameters are altered.
    figsize : tuple
        Specify size of plt.figure.
    timewindow : tuple, list or ndarray
        The voltage and current trace will only be plotted between (t1,t2).
        To be specified in secs.
    effect_on : str ("traces" or "summary stats")
        Specify whether the effect of a correlated change in parameters is
        shown for the voltage traces or for the summary statistics. The
        check is a case-insensitive substring match on "trace" / "summary".
    selected_stats : None, tuple, list or array
        Indices of summary stats to be used.
    savefig : bool
        Determines whether or not to save the plotted figure.
    fig_name : str
        Integrates figure name into the filename of the saved figure.
    summary_func : str
        A descriptive string like "v1" or "v2" can be provided to use one of
        the summary methods already implemented as part of Trace().
    mode : str
        Specifies whether cell/experiment specific parameters like ASoma and
        V0 should be specified as part of the model parameters
        ("model_params") or if ASoma and V0 should be considered part of the
        syn_current or trace object ("trace").
    """
    N = np.max(params.shape)
    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid (N == 2)
    _, axes = plt.subplots(N-1, N-1, figsize=figsize, squeeze=False)

    target = effect_on.lower()
    show_traces = "trace" in target
    show_summary = "summary" in target

    # the base trace is identical for every pair, so simulate it only once
    # and only if traces are plotted at all
    if show_traces:
        base_trace = runHH(params, syn_current=syn_current, mode=mode)

    # loop over pairs in upper triangle of matrix
    for j in range(1, N):
        for i in range(j):
            axs = axes[i, j-1]
            on_diagonal = (i == j-1)

            if show_traces:
                # use the requested magnitude of change, like the summary
                # branch below does
                pars = generate_correlated_parameters(
                    params, corr_mat, pair=(i,j), mag_of_change=mag_of_change)
                correlated_trace = runHH(pars, syn_current=syn_current,
                                         mode=mode)

                # plot traces
                base_trace.inspect(axes=axs, voltage_only=True,
                                   timewindow=timewindow)
                correlated_trace.inspect(axes=axs, voltage_only=True,
                                         timewindow=timewindow)

                # format axes
                if not on_diagonal:
                    axs.get_xaxis().set_ticks([])
                    axs.get_yaxis().set_ticks([])
                    axs.set_xticklabels([])
                    axs.set_yticklabels([])
                    axs.set_xlabel(None)
                    axs.set_ylabel(None)
                    plt.subplots_adjust(hspace = .65, wspace = .2)
                else:
                    # N-1 is the last column index (7 for eight parameters)
                    if i == 0:
                        axs.set_ylabel("V [mV]", fontsize=10)
                        axs.set_xlabel("j = "+str(j), fontsize=10)
                    if j == N-1:
                        axs.set_xlabel("t [ms]", fontsize=10)
                        axs.set_ylabel("i = "+ str(i), fontsize=10)
                    if i != 0 and j != N-1:
                        axs.set_xlabel("j = "+ str(j), fontsize=10)
                        axs.set_ylabel("i =" + str(i), fontsize=10)

            if show_summary:
                summary_df = compare_correlated_summary_stats(
                    params, syn_current, (i,j), corr_mat, mag_of_change,
                    selected_stats, summary_func=summary_func, mode=mode)
                base_stats = summary_df["base params"]
                # relative deviation from the base summary in percent; the
                # column label keeps its literal backslash before '%'
                changes_df = pd.DataFrame(
                    abs((base_stats - summary_df["correlated params"])/base_stats)*100,
                    columns=[r"$\Delta_{corr}$"+" = {}\\%".format(mag_of_change*100)])

                # plot changes
                changes_df.plot(kind="bar", ax=axs, color="black",
                                legend=False, width=1)

                # format axes and plots
                axs.set_xticklabels([])
                axs.set_yticklabels([])
                axs.get_xaxis().set_ticks([])
                axs.get_yaxis().set_ticks([])
                axs.set_ylim([0,100])
                plt.subplots_adjust(hspace = .65, wspace = .2)
                plt.suptitle("Percentage change in the summary statistics")

                if on_diagonal:
                    axs.get_yaxis().set_ticks([0,100])
                    axs.set_yticklabels([0,100])

                    if j == N-1 and i == N-2:
                        axs.set_xlabel("Summary Stats", fontsize=10)
                        axs.set_ylabel("i = "+ str(i), fontsize=10)
                    else:
                        axs.set_xlabel("j = "+ str(j), fontsize=10)
                        axs.set_ylabel("i =" + str(i), fontsize=10)

                    # set last so that the unit label of the first subplot
                    # is not overwritten by the generic "i = ..." label
                    if i == 0 and j == 1:
                        axs.set_ylabel(r"$\Delta$ [%]", fontsize=10)

    # blank out the unused subplots below the diagonal of the grid
    for row in range(1, N-1):
        for col in range(row):
            axs = axes[row, col]
            axs.get_xaxis().set_ticks([])
            axs.get_yaxis().set_ticks([])
            sns.despine(ax = axs, left = True, bottom = True)

    # NOTE(review): this clears the figure title again, including the one
    # set for the summary plots above - confirm that this is intended.
    plt.suptitle("")

    if savefig:
        plt.savefig('../data/figures/{}_8paramprior_support_corr_change_effect_on_{}.png'.format(fig_name, effect_on))