Ejemplo n.º 1
0
def model_fit(model, ds, bold_transient=10000, fc=True, fcd=False):
    """Score a model's simulated BOLD signal against the empirical data in ``ds``.

    Only BOLD samples whose time stamp in ``model.BOLD.t_BOLD`` is strictly
    greater than ``bold_transient`` are used, so the initial transient of the
    simulation is discarded.

    :param model: Model exposing ``BOLD.BOLD`` (2D array whose second axis is
        time, masked by ``BOLD.t_BOLD``) and ``BOLD.t_BOLD``.
    :param ds: Empirical dataset. Needs ``FCs`` when ``fc`` is set. When
        ``fcd`` is set, it needs either ``FCDs`` (precomputed FCD matrices) or
        ``BOLDs`` (raw BOLD time series).
    :param bold_transient: BOLD samples with ``t_BOLD <= bold_transient`` are
        dropped (same time units as ``t_BOLD``).
    :param fc: If true, compute FC correlation scores.
    :param fcd: If true, compute FCD comparison scores.
    :return: dict with ``"fc_scores"`` / ``"mean_fc_score"`` (when ``fc``)
        and ``"fcd"`` / ``"mean_fcd"`` (when ``fcd``). The list entries hold
        one score per empirical subject, in the order of ``ds``.
    """
    result = {}
    if not (fc or fcd):
        return result

    # Drop the initial transient once; every score below uses this slice.
    bold = model.BOLD.BOLD[:, model.BOLD.t_BOLD > bold_transient]

    if fc:
        # The simulated FC does not depend on the empirical matrix, so it is
        # computed once instead of once per subject.
        fc_sim = func.fc(bold)
        result["fc_scores"] = [
            func.matrix_correlation(fc_sim, fc_emp) for fc_emp in ds.FCs
        ]
        result["mean_fc_score"] = np.mean(result["fc_scores"])

    if fcd:
        if hasattr(ds, "FCDs"):
            # Empirical FCD matrices are already available: compare matrices.
            # The simulated FCD is only needed on this path.
            fcd_sim = func.fcd(bold)
            fcd_scores = [
                func.matrix_kolmogorov(fcd_sim, fcd_emp) for fcd_emp in ds.FCDs
            ]
        else:
            # Otherwise compare directly against the raw empirical BOLD series.
            fcd_scores = [
                func.ts_kolmogorov(bold, bold_emp) for bold_emp in ds.BOLDs
            ]
        result["fcd"] = fcd_scores
        result["mean_fcd"] = np.mean(fcd_scores)

    return result
Ejemplo n.º 2
0
def plot_outputs(model,
                 ds=None,
                 activity_xlim=None,
                 bold_transient=10000,
                 spectrum_windowsize=1,
                 plot_fcd=None):
    """Plot an overview figure of a model's simulation outputs and show it.

    Row 0 (when ``"t"`` is in ``model.outputs``) shows three panels:
    activity traces with averages, an activity heat map, and power spectra
    with annotated peaks.

    Row 1 (when ``"BOLD"`` is in ``model.outputs``) shows the post-transient
    BOLD signal, the simulated FC matrix, and the FC fit to each empirical FC
    as a function of simulation time.

    Row 2 (when ``plot_fcd`` is truthy) shows the simulated FCD matrix, FCD
    value distributions, and a phase plot of the first node's r_exc vs. r_inh.

    :param model: Simulated model.
    :param ds: Optional empirical dataset (``FCs``, optionally ``FCDs`` or
        ``BOLDs``). Panels that compare against data are left without that
        comparison when ``ds`` is None.
    :param activity_xlim: x-limits for the activity panel (seconds).
    :param bold_transient: BOLD samples with ``t_BOLD <= bold_transient``
        are discarded.
    :param spectrum_windowsize: Forwarded to the power-spectrum helpers.
    :param plot_fcd: Whether to add the FCD row. ``None`` enables it
        automatically when the BOLD signal is long enough.
    """
    # Minimum BOLD duration (seconds) for which the FCD row is drawn.
    FCD_THRESHOLD = 60
    if "BOLD" in model.outputs and plot_fcd is None:
        # The division by 2 accounts for the BOLD sampling rate.
        plot_fcd = len(model.BOLD.BOLD.T) > FCD_THRESHOLD / 2

    nrows = 3 if plot_fcd else 2
    fig, axs = plt.subplots(nrows, 3, figsize=(12, nrows * 3), dpi=150)

    if "t" in model.outputs:
        _plot_activity_row(axs, model, activity_xlim, spectrum_windowsize)

    if "BOLD" in model.outputs:
        bold = _plot_bold_row(axs, model, ds, bold_transient)
        if plot_fcd:
            _plot_fcd_row(axs, model, ds, bold, bold_transient)

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def _plot_activity_row(axs, model, activity_xlim, spectrum_windowsize):
    """Fill row 0: activity traces, activity heat map and power spectra."""
    t_sec = model.outputs.t / 1000

    # Per-node traces plus overall, odd-index ("L") and even-index ("R") means.
    ax = axs[0, 0]
    ax.set_ylabel("Activity")
    ax.set_xlabel("Time [s]")
    ax.plot(t_sec, model.output.T, alpha=0.8, lw=0.5)
    ax.plot(t_sec,
            np.mean(model.output, axis=0),
            c="r",
            alpha=0.8,
            lw=1.5,
            label="average",
            zorder=3)
    ax.plot(t_sec,
            np.mean(model.output[1::2, :], axis=0),
            c="k",
            alpha=0.8,
            lw=1,
            label="L average")
    ax.plot(t_sec,
            np.mean(model.output[::2, :], axis=0),
            c="k",
            alpha=0.8,
            lw=1,
            label="R average")
    ax.set_xlim(activity_xlim)

    # Node x time heat map of the activity.
    ax = axs[0, 1]
    ax.set_ylabel("Node")
    ax.set_xlabel("Time [s]")
    ax.imshow(
        model.output,
        aspect="auto",
        extent=[0, model.t[-1] / 1000, model.params.N, 0],
        clim=(0, 10),
    )

    # Per-node spectra, the mean spectrum, and its annotated peaks.
    ax = axs[0, 2]
    ax.set_ylabel("Power")
    ax.set_xlabel("Frequency [Hz]")
    for node_output in model.output:
        frs, pwrs = getPowerSpectrum(node_output,
                                     dt=model.params.dt,
                                     spectrum_windowsize=spectrum_windowsize)
        ax.plot(frs, pwrs, alpha=0.8, lw=0.5)
    frs, pwrs = getMeanPowerSpectrum(model.output,
                                     dt=model.params.dt,
                                     spectrum_windowsize=spectrum_windowsize)
    ax.plot(frs, pwrs, lw=3, c="springgreen")

    peaks = scipy.signal.find_peaks_cwt(pwrs, np.arange(2, 3))
    for p in peaks:
        ax.scatter(frs[p], pwrs[p], c="springgreen", zorder=20)
        # The label is passed positionally because matplotlib renamed the
        # ``s=`` keyword to ``text`` (deprecated in 3.3, later removed).
        ax.annotate("  {0:.1f} Hz".format(frs[p]),
                    xy=(frs[p], pwrs[p]),
                    fontsize=10)


def _plot_bold_row(axs, model, ds, bold_transient):
    """Fill row 1 (BOLD, FC, FC fit over time); return the post-transient BOLD."""
    t_bold_all = model.outputs.BOLD.t_BOLD
    keep = t_bold_all > bold_transient
    t_bold = t_bold_all[keep] / 1000
    bold = model.outputs.BOLD.BOLD[:, keep]

    ax = axs[1, 0]
    ax.set_ylabel("BOLD")
    ax.set_xlabel("Time [s]")
    ax.plot(t_bold, bold.T, lw=1.5, alpha=0.8)

    ax = axs[1, 1]
    ax.set_title("FC", fontsize=12)
    if ds is not None:
        fc_fit = model_fit(model, ds, bold_transient,
                           fc=True)["mean_fc_score"]
        ax.set_title(f"FC (corr: {fc_fit:0.2f})", fontsize=12)
    ax.imshow(func.fc(bold), origin="upper")
    ax.set_ylabel("Node")
    ax.set_xlabel("Node")

    ax = axs[1, 2]
    ax.set_title("FC corr over time", fontsize=12)
    if ds is not None:
        # The growing-window simulated FCs do not depend on the empirical
        # matrix, so they are computed once and reused for every subject.
        fc_sims = [func.fc(bold[:, :t]) for t in range(2, bold.shape[1])]
        ax.plot(
            np.arange(4, bold.shape[1] * 2, step=2),
            np.array([[
                func.matrix_correlation(fc_sim, fc_emp) for fc_sim in fc_sims
            ] for fc_emp in ds.FCs]).T,
        )
    ax.set_ylabel("FC fit")
    ax.set_xlabel("Simulation time [s]")
    return bold


def _plot_fcd_row(axs, model, ds, bold, bold_transient):
    """Fill row 2: simulated FCD, FCD value distributions, r_exc/r_inh phase plot."""

    def upper_triangle(matrix):
        # Values strictly above the diagonal.
        return matrix[np.triu_indices(matrix.shape[0], k=1)]

    # Compute the simulated FCD once; it is used by the first two panels.
    fcd_sim = func.fcd(bold)

    ax = axs[2, 0]
    ax.set_title("FCD", fontsize=12)
    ax.set_ylabel("$n_{window}$")
    ax.set_xlabel("$n_{window}$")
    ax.imshow(fcd_sim, origin="upper")

    ax = axs[2, 1]
    if ds is not None:
        # Only the FCD score is needed here, so skip the FC computation.
        fcd_fit = model_fit(model, ds, bold_transient, fc=False,
                            fcd=True)["mean_fcd"]
        ax.set_title(f"FCD distance {fcd_fit:0.2f}", fontsize=12)
    ax.set_ylabel("P")
    ax.set_xlabel("triu(FCD)")
    ax.hist(upper_triangle(fcd_sim),
            density=True,
            color="springgreen",
            zorder=10,
            alpha=0.6)
    # Overlay the distributions of the empirical FCDs, if available.
    if hasattr(ds, "FCDs"):
        for emp_fcd in ds.FCDs:
            ax.hist(upper_triangle(emp_fcd), density=True, alpha=0.5)

    # Phase plot of the first node's excitatory vs. inhibitory rate. It is
    # skipped for models that do not expose both rate outputs.
    ax = axs[2, 2]
    if "rates_exc" in model.outputs and "rates_inh" in model.outputs:
        ax.plot(model.outputs.rates_exc[0, :],
                model.outputs.rates_inh[0, :],
                lw=0.5)
        ax.set_xlabel("$r_{exc}$")
        ax.set_ylabel("$r_{inh}$")
Ejemplo n.º 3
0
 def test_fcd(self):
     rFCD = func.fcd(self.model.rates_exc, stepsize=100)