def plot_optimisation_results(candidate, fitness, max_evaluations, pop_size):
    """
    Plot a comparison of the model, the experimental data and the
    original model by Rothman (Schwartz2012) across all experimental
    traces. The figure gets saved to disk in the current folder.

    Parameters
    ----------
    candidate : sequence of float
        Model parameters; ``candidate[0:7]`` drive the direct component
        and ``candidate[7:]`` the spillover component.
    fitness : float
        Fitness of ``candidate``; only shown in the figure title.
    max_evaluations, pop_size
        Optimiser settings; only shown in the figure title.

    Each subplot shows one experimental trace (black, thick), its pulse
    times (black dots just below zero), the reference fit read from
    ``AMPA_jason_fit_traces.hdf5`` (red) and the model prediction
    (green). The title also reports the reference fit's fitness: the
    mean over traces of ||reference - data||_2 / sqrt(n_timepoints).
    The figure is written to ``Rothman_AMPA_TM_fit_<time.time()>.png``
    and then closed.
    """
    # NOTE(review): figsize=(160, 80) at dpi=500 is 80000x40000 pixels;
    # matplotlib's Agg renderer rejects images of 2**16 pixels or more in
    # either direction, so saving to PNG may fail -- confirm intended size.
    fig, ax = plt.subplots(nrows=8, ncols=4, figsize=(160,80), dpi=500)
    rothman_fitness = 0

    # Reference traces, one per (frequency, protocol) pair. Open the file
    # read-only: older h5py versions default to append mode, which would
    # silently create an empty file if the path were missing.
    # NOTE(review): assumes this frequency-major / protocol-minor order
    # matches the order of problem.exp_pulses -- confirm.
    java_fit_signals = []
    with h5py.File("AMPA_jason_fit_traces.hdf5", "r") as java_fit_data_repo:
        for freq in problem.frequencies:
            for prot in range(problem.n_protocols):
                data_group = java_fit_data_repo['/{0}/{1}'.format(freq, prot)]
                # keep only the signal column, dropping the last sample
                java_fit_signals.append(np.array(data_group['average_waveform'][:-1,1]))

    for k, ep in enumerate(problem.exp_pulses):
        timepoints = problem.exp_data[k][:,0]
        signal_direct = synthetic_conductance_signal_direct(timepoints,
                                                            ep,
                                                            problem.single_waveform_lengths[k],
                                                            problem.timestep_sizes[k],
                                                            0.54,
                                                            *candidate[0:7])
        signal_spillover = synthetic_conductance_signal_spillover(timepoints,
                                                                  ep,
                                                                  problem.single_waveform_lengths[k],
                                                                  problem.timestep_sizes[k],
                                                                  0.54,
                                                                  *candidate[7:])
        # NOTE(review): the reference signal is plotted on the experimental
        # time axis, so it is assumed to have the same length -- confirm.
        rothman_signal = java_fit_signals[k]
        rothman_fitness += np.linalg.norm(rothman_signal - problem.exp_data[k][:,1])/np.sqrt(timepoints.shape[0])
        ax.flat[k].plot(timepoints, problem.exp_data[k][:,1], color='k', linewidth=3)
        ax.flat[k].scatter(ep, np.zeros(shape=ep.shape)-0.05, color='k')
        ax.flat[k].plot(timepoints, rothman_signal, linewidth=1, color='r')
        ax.flat[k].plot(timepoints, signal_direct+signal_spillover, linewidth=1, color='g')
    rothman_fitness /= len(problem.exp_pulses)
    fig.suptitle('parameters: {0}\n fitness: {1} max_evaluations: {2} pop_size: {3}\nRothman2012 fitness: {4}'.format(candidate, fitness, max_evaluations, pop_size, rothman_fitness))
    fig.savefig('Rothman_AMPA_TM_fit_{0}.png'.format(time.time()))
    # Release the (very large) figure so repeated calls do not keep
    # accumulating open figures in pyplot's state.
    plt.close(fig)
def scale_to_sargent(candidate):
    """Scale fit values to match peak AMPA reported in Sargent2005.

    The fitted model (direct component from ``candidate[:7]``, spillover
    component from ``candidate[7:]``) is driven with a single pulse at
    t = 10 on a 0-300 time axis with step 0.01 and zero delay. The factor
    that maps the peak of the summed signal onto 0.63 nS is applied to
    the amplitude parameters (indices 1, 2, 8, 9 and 10). The scaled
    parameters are printed, and a plot of the unscaled and scaled
    signals is shown.

    ``candidate`` itself is left unmodified; the scaling is applied to a
    copy.
    """
    import copy

    sargent_peak = 0.63 # (nS)

    timestep = 0.01
    timepoints = np.arange(0, 300, timestep)
    pulse_times = np.array([10.])
    single_waveform_length = timepoints.shape[0]
    signal_direct = synthetic_conductance_signal_direct(timepoints,
                                                        pulse_times,
                                                        single_waveform_length,
                                                        timestep,
                                                        0.,
                                                        *candidate[:7])
    signal_spillover = synthetic_conductance_signal_spillover(timepoints,
                                                              pulse_times,
                                                              single_waveform_length,
                                                              timestep,
                                                              0.,
                                                              *candidate[7:])
    signal = signal_direct + signal_spillover
    scaling_factor = sargent_peak/signal.max()
    scaled_signal = scaling_factor * signal
    # copy.copy instead of candidate[:]: slicing a NumPy array returns a
    # view, so the in-place scaling below would also overwrite the
    # caller's parameters. copy.copy yields an independent list or array.
    scaled_candidate = copy.copy(candidate)
    for k in [1,2,8,9,10]:
        # scale 'amplitude' parameters
        scaled_candidate[k] *= scaling_factor
    print(scaled_candidate)
    # rounded scaled candidate should be [0.3274, 3.724, 0.3033,
    # 0.3351, 1.651, 0.1249, 131.0, 0.5548, 0.2487, 0.2799, 0.1268,
    # 0.4, 4.899, 43.1, 0.2792, 14.85]

    fig, ax = plt.subplots()
    ax.plot(timepoints, signal, label="fit to Rothman data")
    ax.plot(timepoints, scaled_signal, label="scaled to peak value by Sargent")
    ax.legend(loc="best")
    fig.suptitle("direct {0}\nspillover {1}".format(scaled_candidate[:7],
                                                    scaled_candidate[7:]))
    plt.show()
def plot_single_comparison(filename, candidate, pulse_n, timepoints, trace, timestep=0.025, delay=0):
    """
    Plot one trace against the model prediction, save it and show it.

    Parameters
    ----------
    filename : str
        Output path handed to ``savefig``.
    candidate : sequence of float
        Model parameters; ``candidate[:7]`` drive the direct component
        and ``candidate[7:]`` the spillover component.
    pulse_n : int
        Index into ``problem.exp_pulses`` selecting the pulse times.
    timepoints, trace : array-like
        Time axis and signal of the trace to compare against.
    timestep : float, optional
        Sampling step passed to the synthetic signal generators.
    delay : float, optional
        Delay passed to the synthetic signal generators.

    The trace is drawn in black, the model (direct + spillover) in red
    and the pulse times as black dots at zero. The figure is saved at
    100 dpi and then displayed with ``plt.show()``.
    """
    pulse_times = problem.exp_pulses[pulse_n]
    # NOTE(review): the waveform length is always taken from experiment 8,
    # independently of pulse_n -- confirm this is intended.
    single_waveform_length = problem.single_waveform_lengths[8]

    signal_direct = synthetic_conductance_signal_direct(timepoints,
                                                        pulse_times,
                                                        single_waveform_length,
                                                        timestep,
                                                        delay,
                                                        *candidate[:7])
    signal_spillover = synthetic_conductance_signal_spillover(timepoints,
                                                              pulse_times,
                                                              single_waveform_length,
                                                              timestep,
                                                              delay,
                                                              *candidate[7:])

    fig, ax = plt.subplots()
    # plot trace we are comparing
    ax.plot(timepoints, trace, linewidth=2.5, color='k')
    # plot model trace
    ax.plot(timepoints, signal_direct+signal_spillover, linewidth=2.5, color='r')
    # plot spike raster
    displacement = 0. # nS
    ax.scatter(pulse_times,
               np.zeros_like(pulse_times)-displacement,
               marker="o",
               color='k')
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('amplitude (a.u.)')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.spines['top'].set_color('none')
    ax.spines['right'].set_color('none')
    # savefig has no `size` keyword (older matplotlib silently ignored it,
    # newer versions raise TypeError); set the size on the figure instead
    # if a specific size is ever needed.
    fig.savefig(filename, dpi=100)
    plt.show()
def fitness_to_experiment(cs):
    """
    Return the mean per-recording distance between model and experiment.

    For each experimental recording, the model signal is the sum of the
    direct component (parameters ``cs[:7]``) and the spillover component
    (parameters ``cs[7:]``), both generated on that recording's own time
    axis with a delay of 0.54. The recording's distance is the L2 norm of
    (model - recorded signal) divided by the square root of its number
    of time points, i.e. a root-mean-square difference. Every time point
    of every recording therefore carries equal weight, and long
    recordings are not favoured. The result is the arithmetic mean of
    these distances over all recordings.
    """
    direct_params = cs[:7]
    spillover_params = cs[7:]
    delay = 0.54

    def _trace_distance(idx, pulses):
        # root-mean-square difference between model and recording ``idx``
        recording = problem.exp_data[idx]
        t = recording[:,0]
        n_samples = problem.single_waveform_lengths[idx]
        dt = problem.timestep_sizes[idx]
        model = (synthetic_conductance_signal_direct(t, pulses, n_samples, dt,
                                                     delay, *direct_params)
                 + synthetic_conductance_signal_spillover(t, pulses, n_samples, dt,
                                                          delay, *spillover_params))
        return np.linalg.norm(model - recording[:,1]) / np.sqrt(t.shape[0])

    per_trace = [_trace_distance(idx, pulses)
                 for idx, pulses in enumerate(problem.exp_pulses)]
    return sum(per_trace) / len(per_trace)