示例#1
0
def model_fit(model, ds, bold_transient=10000, fc=True, fcd=False):
    """Score the simulated BOLD signal against an empirical dataset.

    :param model: simulated model exposing ``BOLD.BOLD`` (signal, indexed
        ``[:, samples]``) and ``BOLD.t_BOLD`` (sample times)
    :param ds: dataset with ``FCs`` (used when ``fc``) and ``FCDs`` or
        ``BOLDs`` (used when ``fcd``)
    :param bold_transient: only samples with ``t_BOLD > bold_transient`` are
        scored (same units as ``t_BOLD``)
    :param fc: add ``fc_scores`` (``func.matrix_correlation`` of the simulated
        FC with each empirical FC) and their mean ``mean_fc_score``
    :param fcd: add ``fcd`` scores and their mean ``mean_fcd``; uses
        ``func.matrix_kolmogorov`` on FCD matrices when ``ds`` has ``FCDs``,
        otherwise ``func.ts_kolmogorov`` on the raw BOLD time series
    :return: dict holding the requested entries (empty if neither flag is set)
    """
    result = {}
    if not (fc or fcd):
        return result

    # Discard the initial transient once; every score uses the same slice.
    bold = model.BOLD.BOLD[:, model.BOLD.t_BOLD > bold_transient]

    if fc:
        # The simulated FC is the same for every empirical FC: compute it once.
        sim_fc = func.fc(bold)
        result["fc_scores"] = [
            func.matrix_correlation(sim_fc, emp_fc) for emp_fc in ds.FCs
        ]
        result["mean_fc_score"] = np.mean(result["fc_scores"])

    if fcd:
        if hasattr(ds, "FCDs"):
            # empirical FCD matrices are precomputed: compare FCD to FCD;
            # the simulated FCD is only computed on this path
            fcd_sim = func.fcd(bold)
            fcd_scores = [
                func.matrix_kolmogorov(fcd_sim, fcd_emp) for fcd_emp in ds.FCDs
            ]
        else:
            fcd_scores = [
                func.ts_kolmogorov(bold, emp_bold) for emp_bold in ds.BOLDs
            ]
        result["fcd"] = fcd_scores
        result["mean_fcd"] = np.mean(fcd_scores)

    return result
示例#2
0
def plot_outputs(model, bold_transient=10000):
    """Quick-look plot of a model run on a 2x3 grid, then show it.

    Top-left: ``model.output`` over ``model.outputs.t`` (if ``t`` exists).
    Bottom-left / bottom-middle: the BOLD signal after ``bold_transient`` and
    its FC matrix (if BOLD output exists).
    """
    axes = plt.subplots(2, 3, figsize=(12, 6))[1]
    outputs = model.outputs

    if "t" in outputs:
        axes[0, 0].plot(outputs.t, model.output.T)

    if "BOLD" in outputs:
        bold_out = outputs.BOLD
        # keep only samples after the initial transient
        after_transient = bold_out.t_BOLD > bold_transient
        bold_signal = bold_out.BOLD[:, after_transient]
        axes[1, 0].plot(bold_out.t_BOLD[after_transient], bold_signal.T)
        axes[1, 1].imshow(func.fc(bold_signal))

    plt.show()
        def evaluateSimulation(traj):
            """Run a staged simulation for one trajectory point and score its FC.

            Two cheap pre-checks (stages 1 and 2) reject parameter sets whose
            excitatory rates are out of range or whose BOLD signal is flat,
            before the expensive full-length run (stage 3). Rejected points
            store ``invalid_result`` via ``search.saveOutputsToPypet``.

            :param traj: trajectory object handed in by ``search``
            :return: ``(result_dict, {})`` on success, ``(invalid_result, {})``
                when a pre-check fails
            """
            model = search.getModelFromTraj(traj)
            defaultDuration = model.params["duration"]
            invalid_result = {"fc": [0] * len(ds.BOLDs)}

            # -------- stage wise simulation --------

            # Stage 1: simulate for a few seconds to see if there is any activity
            # ---------------------------------------
            model.params["dt"] = 0.1
            model.params["duration"] = 3 * 1000.0
            model.run()

            # reject runaway (> 300) or near-silent (< 10) excitatory rates,
            # ignoring the initial t <= 500
            max_rate = np.max(model.rates_exc[:, model.t > 500])
            if max_rate > 300 or max_rate < 10:
                search.saveOutputsToPypet(invalid_result, traj)
                return invalid_result, {}

            # Stage 2: simulate BOLD for a few seconds to see if it moves
            # ---------------------------------------
            model.params["dt"] = 0.2
            model.params["duration"] = 20 * 1000.0
            model.run(bold=True)

            if np.std(model.BOLD.BOLD[:, 5:10]) < 0.001:
                search.saveOutputsToPypet(invalid_result, traj)
                return invalid_result, {}

            # Stage 3: full and final simulation
            # ---------------------------------------
            # BOLD must be simulated here too: the FC below reads model.BOLD,
            # and without bold=True that would not reflect this full-length
            # run (the last BOLD simulation was the 20 s one from stage 2).
            model.params["dt"] = 0.2
            model.params["duration"] = defaultDuration
            model.run(bold=True)

            # -------- evaluation here --------

            # the simulated FC (first 5 BOLD samples dropped) does not depend
            # on the empirical FC it is compared with: compute it once
            sim_fc = func.fc(model.BOLD.BOLD[:, 5:])
            scores = [func.matrix_correlation(sim_fc, emp_fc) for emp_fc in ds.FCs]

            result_dict = {"fc": np.mean(scores)}

            search.saveOutputsToPypet(result_dict, traj)
            # same return shape as the rejection paths above
            return result_dict, {}
示例#4
0
def model_fit(model, ds, bold_transient=10000, fc=True, fcd=False):
    """Score the simulated BOLD signal against an empirical dataset.

    :param model: simulated model exposing ``BOLD.BOLD`` (signal, indexed
        ``[:, samples]``) and ``BOLD.t_BOLD`` (sample times)
    :param ds: dataset with ``FCs`` (used when ``fc``) and ``BOLDs`` (used
        when ``fcd``)
    :param bold_transient: only samples with ``t_BOLD > bold_transient`` are
        scored (same units as ``t_BOLD``)
    :param fc: add ``fc_scores`` (``func.matrix_correlation`` of the simulated
        FC with each empirical FC) and their mean ``mean_fc_score``
    :param fcd: add ``fcd`` (``func.ts_kolmogorov`` of the simulated BOLD
        against each empirical BOLD) and their mean ``mean_fcd``
    :return: dict holding the requested entries (empty if neither flag is set)
    """
    result = {}
    if not (fc or fcd):
        return result

    # Discard the initial transient once; every score uses the same slice.
    bold = model.BOLD.BOLD[:, model.BOLD.t_BOLD > bold_transient]

    if fc:
        # The simulated FC is the same for every empirical FC: compute it once.
        sim_fc = func.fc(bold)
        result["fc_scores"] = [
            func.matrix_correlation(sim_fc, emp_fc) for emp_fc in ds.FCs
        ]
        result["mean_fc_score"] = np.mean(result["fc_scores"])

    if fcd:
        fcd_scores = [func.ts_kolmogorov(bold, emp_bold) for emp_bold in ds.BOLDs]
        result["fcd"] = fcd_scores
        result["mean_fcd"] = np.mean(fcd_scores)

    return result
示例#5
0
def plot_outputs(model,
                 ds=None,
                 activity_xlim=None,
                 bold_transient=10000,
                 spectrum_windowsize=1,
                 plot_fcd=None):
    """Plot an overview figure of a simulated model run, then show it.

    Row 0 (if the model has a ``t`` output): activity time series with the
    overall, ``[1::2]`` ("L") and ``[::2]`` ("R") averages; an activity heat
    map; per-output and mean power spectra with the mean spectrum's peaks
    annotated.
    Row 1 (if the model has BOLD output): post-transient BOLD, its FC matrix,
    and -- when ``ds`` is given -- the FC fit to each empirical FC as a
    function of the simulated time window.
    Row 2 (if ``plot_fcd``): the FCD matrix, the distribution of its upper
    triangle (overlaid with ``ds.FCDs`` if present), and an ``r_exc``/``r_inh``
    trajectory of output row 0 if both outputs exist.

    :param model: simulated model with ``outputs``
    :param ds: empirical dataset with ``FCs`` (optionally ``FCDs``/``BOLDs``);
        ``None`` skips every comparison with data
    :param activity_xlim: x-limits of the activity plot, in seconds
    :param bold_transient: BOLD samples with ``t_BOLD <= bold_transient`` are
        dropped (same units as ``t_BOLD``)
    :param spectrum_windowsize: forwarded to the power-spectrum helpers
    :param plot_fcd: force the FCD row on/off; ``None`` enables it when the
        BOLD signal is longer than ``FCD_THRESHOLD`` seconds
    """
    # check if BOLD signal is long enough for FCD
    FCD_THRESHOLD = 60  # seconds
    if "BOLD" in model.outputs and plot_fcd is None:
        # div by 2 because of the BOLD sampling rate (one sample per 2 s)
        if len(model.BOLD.BOLD.T) > FCD_THRESHOLD / 2:
            plot_fcd = True

    nrows = 3 if plot_fcd else 2
    fig, axs = plt.subplots(nrows, 3, figsize=(12, nrows * 3), dpi=150)

    if "t" in model.outputs:
        t_sec = model.outputs.t / 1000

        # activity time series and population averages
        axs[0, 0].set_ylabel("Activity")
        axs[0, 0].set_xlabel("Time [s]")
        axs[0, 0].plot(t_sec, model.output.T, alpha=0.8, lw=0.5)
        axs[0, 0].plot(t_sec,
                       np.mean(model.output, axis=0),
                       c="r",
                       alpha=0.8,
                       lw=1.5,
                       label="average",
                       zorder=3)
        axs[0, 0].plot(t_sec,
                       np.mean(model.output[1::2, :], axis=0),
                       c="k",
                       alpha=0.8,
                       lw=1,
                       label="L average")
        axs[0, 0].plot(t_sec,
                       np.mean(model.output[::2, :], axis=0),
                       c="k",
                       alpha=0.8,
                       lw=1,
                       label="R average")
        axs[0, 0].set_xlim(activity_xlim)

        # activity heat map (rows of model.output against time)
        axs[0, 1].set_ylabel("Node")
        axs[0, 1].set_xlabel("Time [s]")
        axs[0, 1].imshow(
            model.output,
            aspect="auto",
            extent=[0, model.t[-1] / 1000, model.params.N, 0],
            clim=(0, 10),
        )

        # output frequency spectrum
        axs[0, 2].set_ylabel("Power")
        axs[0, 2].set_xlabel("Frequency [Hz]")
        for o in model.output:
            frs, pwrs = getPowerSpectrum(
                o, dt=model.params.dt, spectrum_windowsize=spectrum_windowsize)
            axs[0, 2].plot(frs, pwrs, alpha=0.8, lw=0.5)
        frs, pwrs = getMeanPowerSpectrum(
            model.output,
            dt=model.params.dt,
            spectrum_windowsize=spectrum_windowsize)
        axs[0, 2].plot(frs, pwrs, lw=3, c="springgreen")

        # frequency spectrum annotations: mark peaks of the mean spectrum
        peaks = scipy.signal.find_peaks_cwt(pwrs, np.arange(2, 3))
        for p in peaks:
            axs[0, 2].scatter(frs[p], pwrs[p], c="springgreen", zorder=20)
            # The label is passed positionally: Matplotlib renamed the
            # ``s=`` keyword to ``text`` (3.3) and later removed ``s``.
            axs[0, 2].annotate("  {0:.1f} Hz".format(frs[p]),
                               xy=(frs[p], pwrs[p]),
                               fontsize=10)

    if "BOLD" in model.outputs:
        # BOLD plotting ----------
        t_bold_all = model.outputs.BOLD.t_BOLD
        after_transient = t_bold_all > bold_transient
        t_bold = t_bold_all[after_transient] / 1000
        bold = model.outputs.BOLD.BOLD[:, after_transient]
        axs[1, 0].set_ylabel("BOLD")
        axs[1, 0].set_xlabel("Time [s]")
        axs[1, 0].plot(t_bold, bold.T, lw=1.5, alpha=0.8)

        axs[1, 1].set_title("FC", fontsize=12)
        if ds is not None:
            fc_fit = model_fit(model, ds, bold_transient,
                               fc=True)["mean_fc_score"]
            axs[1, 1].set_title(f"FC (corr: {fc_fit:0.2f})", fontsize=12)
        axs[1, 1].imshow(func.fc(bold), origin="upper")
        axs[1, 1].set_ylabel("Node")
        axs[1, 1].set_xlabel("Node")

        # FC fit against each empirical FC as the simulated window grows;
        # needs empirical data, so it is skipped when ds is None
        axs[1, 2].set_title("FC corr over time", fontsize=12)
        if ds is not None:
            fit_over_time = np.array([[
                func.matrix_correlation(func.fc(bold[:, :t]), emp_fc)
                for t in range(2, bold.shape[1])
            ] for emp_fc in ds.FCs]).T
            axs[1, 2].plot(np.arange(4, bold.shape[1] * 2, step=2),
                           fit_over_time)
        axs[1, 2].set_ylabel("FC fit")
        axs[1, 2].set_xlabel("Simulation time [s]")

        # FCD plotting ------------
        if plot_fcd:
            sim_fcd = func.fcd(bold)  # reused by the image and the histogram

            # plot image of fcd
            axs[2, 0].set_title("FCD", fontsize=12)
            axs[2, 0].set_ylabel("$n_{window}$")
            axs[2, 0].set_xlabel("$n_{window}$")
            axs[2, 0].imshow(sim_fcd, origin="upper")

            # plot distribution in fcd
            if ds is not None:
                # fc=False: only the FCD distance is needed here
                fcd_fit = model_fit(model, ds, bold_transient, fc=False,
                                    fcd=True)["mean_fcd"]
                axs[2, 1].set_title(f"FCD distance {fcd_fit:0.2f}",
                                    fontsize=12)
            else:
                axs[2, 1].set_title("FCD distribution", fontsize=12)
            axs[2, 1].set_ylabel("P")
            axs[2, 1].set_xlabel("triu(FCD)")
            triu_sim = sim_fcd[np.triu_indices(sim_fcd.shape[0], k=1)]
            axs[2, 1].hist(triu_sim,
                           density=True,
                           color="springgreen",
                           zorder=10,
                           alpha=0.6)
            # plot fcd distributions of data
            if hasattr(ds, "FCDs"):
                for emp_fcd in ds.FCDs:
                    triu_emp = emp_fcd[np.triu_indices(emp_fcd.shape[0], k=1)]
                    axs[2, 1].hist(triu_emp, density=True, alpha=0.5)

            # r_exc vs r_inh trajectory of output row 0; only models that
            # provide both outputs can be drawn here
            if "rates_exc" in model.outputs and "rates_inh" in model.outputs:
                axs[2, 2].plot(model.outputs.rates_exc[0, :],
                               model.outputs.rates_inh[0, :],
                               lw=0.5)
                axs[2, 2].set_xlabel("$r_{exc}$")
                axs[2, 2].set_ylabel("$r_{inh}$")

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
示例#6
0
 def test_matrix_kolmogorov(self):
     """Smoke test: matrix_kolmogorov runs on two FCs built from the same rates."""
     # NOTE(review): [::20, :] takes every 20th *row* of rates_exc; if rows
     # are nodes this subsamples nodes rather than time -- confirm intent.
     fc_first = func.fc(self.model.rates_exc[::20, :])
     fc_second = func.fc(self.model.rates_exc[::20, :])
     func.matrix_kolmogorov(fc_first, fc_second)
示例#7
0
 def test_matrix_correlation(self):
     """Smoke test: correlate the simulated BOLD FC with the first empirical FC."""
     simulated_fc = func.fc(self.model.BOLD.BOLD)
     empirical_fc = self.ds.FCs[0]
     func.matrix_correlation(simulated_fc, empirical_fc)
示例#8
0
 def test_fc(self):
     """Smoke test: fc runs on the simulated BOLD signal."""
     bold_signal = self.model.BOLD.BOLD
     func.fc(bold_signal)
示例#9
0
def evaluate_model(model, cmat, path, fname, fc_real=False):
    """Plot a simulated run and correlate its dynamics with graph measures.

    Saves, under ``path`` with filename prefix ``fname``: the activity heat
    map (``_ts.png``), band-wise PLV matrices (``_plv_freq_bands.png``), a DTW
    similarity matrix (``_dtw_norm.png``), the BOLD FC (``_bold_fc.png``),
    optionally the given empirical FC (``_bold_fc_real.png``), and a heat map
    plus clustermap of the correlations between all collected measures.

    :param model: simulated model exposing ``output``, ``rates_exc``
        (indexed ``[node, sample]``), ``t``, ``BOLD`` and ``params``
        (``N``, ``dt``, ``duration``)
    :param cmat: structural connectivity matrix handed to ``make_graph``
    :param path: output directory supporting ``/`` (e.g. ``pathlib.Path``)
    :param fname: filename prefix for every saved figure
    :param fc_real: empirical FC matrix, or any ``bool`` to skip it
    :return: DataFrame of pairwise correlations between the collected
        measures, with NaNs replaced by 0
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import mne
    from fastdtw import fastdtw
    from neurolib.utils import functions as func
    import brainplot as bp
    import xarray as xr
    from neurolib.utils.signal import Signal
    from connectivity import make_graph, graph_measures

    # ----------------- Activity ----------------- #
    plt.figure(figsize=(15, 5))
    plt.imshow(
        model.output, aspect="auto",
        extent=[0, model.t[-1] / 1000, model.params.N, 0],
        clim=(0, 20), cmap="viridis",
    )
    cbar = plt.colorbar(extend='max', fraction=0.046, pad=0.04)
    cbar.set_label("Rate $r_{exc}$ [Hz]")
    plt.ylabel("Node")
    plt.xlabel("Time [s]")
    plt.tight_layout()
    plt.savefig(path / f"{fname}_ts.png", dpi=100)
    plt.close()

    data = model.rates_exc
    # Derive the node count from the data instead of hard-coding 80.
    n_nodes = data.shape[0]

    # ----------------- PLV ----------------- #
    # NOTE(review): mne.connectivity.spectral_connectivity was deprecated and
    # later removed from MNE (moved to the mne-connectivity package); this
    # needs an older MNE or a port. sfreq=10000 assumes rates_exc is sampled
    # every 0.1 ms -- TODO confirm. np.split requires the number of samples
    # to be divisible by 12 (the number of epochs).
    con, _, _, _, _ = mne.connectivity.spectral_connectivity(
        np.split(data, 12, axis=1), method='plv',
        sfreq=10000, fmin=(0, 4, 8, 13, 30),
        fmax=(4, 8, 12, 30, 70), faverage=True
    )

    _, ax = plt.subplots(1, 5, figsize=(20, 10), sharey=True)

    all_freq_bands = ("0-4Hz", "4-8Hz", "8-12Hz", "13-30Hz", "30-70Hz")
    for i, (band_ax, freq_label) in enumerate(zip(ax, all_freq_bands)):
        im = band_ax.imshow(con[:, :, i], clim=(0, 1), cmap="viridis")
        band_ax.set_title(f'Frequency band: {freq_label}')
        if i == 0:
            band_ax.set_ylabel("Node")
        band_ax.set_xlabel("Node")

    # one shared colour bar, attached to the last band's panel
    divider = make_axes_locatable(ax[-1])
    cax = divider.append_axes("right", size="5%", pad=0.1)

    cbar = plt.colorbar(im, cax=cax)
    cbar.set_label("Phase-Locking Value")
    plt.tight_layout()
    plt.savefig(path / f"{fname}_plv_freq_bands.png", dpi=100)
    plt.close()

    # ----------------- DTW ----------------- #
    # DTW distance on every 100th sample (capped at 10 000 points) for each
    # node pair in the lower triangle, then mirrored into the upper triangle.
    dtw = np.zeros((n_nodes, n_nodes))
    tril_rows, tril_cols = np.tril_indices(n_nodes)
    for i, j in zip(tril_rows, tril_cols):
        distance, _ = fastdtw(data[i, ::100][:10_000], data[j, ::100][:10_000])
        dtw[i, j] = distance
    # similarity in [0, 1]: 1 for identical series, 0 for the largest distance
    dtw_norm = 1 - dtw / dtw.max()
    upper = np.triu_indices(n_nodes)
    dtw_norm[upper] = dtw_norm.T[upper]

    plt.figure(figsize=(6, 6))
    plt.imshow(dtw_norm, clim=(0, 1), cmap="viridis")
    cbar = plt.colorbar(fraction=0.046, pad=0.04)
    # the plotted values are 1 - d / d_max, not absolute distances
    cbar.set_label("1 - normalized DTW distance")
    plt.ylabel("Node")
    plt.xlabel("Node")
    plt.title("Dynamic Time Warping")
    plt.tight_layout()
    plt.savefig(path / f"{fname}_dtw_norm.png", dpi=100)
    plt.close()

    # ----------------- FC ----------------- #
    # FC of the BOLD signal without its first 10 samples; reused further down
    bold_fc = func.fc(model.BOLD.BOLD[:, 10:])

    plt.figure(figsize=(6, 6))
    plt.imshow(bold_fc, clim=(0, 1), cmap="viridis")
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.ylabel("Node")
    plt.xlabel("Node")
    plt.title("BOLD FC")
    plt.tight_layout()
    plt.savefig(path / f"{fname}_bold_fc.png", dpi=100)
    plt.close()

    if not isinstance(fc_real, bool):
        plt.figure(figsize=(6, 6))
        plt.imshow(fc_real, clim=(0, 1), cmap="viridis")
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.ylabel("Node")
        plt.xlabel("Node")
        plt.title("BOLD FC real data")
        plt.tight_layout()
        plt.savefig(path / f"{fname}_bold_fc_real.png", dpi=100)
        plt.close()

    # ----------------- Involvement ----------------- #
    states = bp.detectSWs(model, filter_long=True)
    involvement = bp.get_involvement(states)

    inv_xr = xr.DataArray(involvement, coords=[model.t], dims=["time"])
    sig = Signal(inv_xr, time_in_ms=True)

    def get_phase(signal, filter_args, pad=None):
        """
        Extract phase of the signal. Steps: detrend -> pad -> filter -> Hilbert
        transform -> get phase -> un-pad.
        :param signal: signal to get phase from
        :type signal: `neuro_signal.Signal`
        :param filter_args: arguments for `Signal`'s filter method (see its
            docstring)
        :type filter_args: dict
        :param pad: how many seconds to pad, if None, won't pad
        :type pad: float|None
        :return: wrapped Hilbert phase of the signal
        :rtype: `neuro_signal.Signal`
        """
        assert isinstance(signal, Signal)
        phase = signal.detrend(inplace=False)
        if pad:
            phase.pad(
                how_much=pad, in_seconds=True,
                padding_type="reflect", side="both"
            )
        phase.filter(**filter_args)
        phase.hilbert_transform(return_as="phase_wrapped")
        if pad:
            phase.sel([phase.start_time + pad, phase.end_time - pad])
        return phase

    phase = get_phase(sig, filter_args={"low_freq": 0.5, "high_freq": 2})
    (node_mean_phases_down,
     node_mean_phases_up) = bp.get_transition_phases(states, phase.data)
    node_mean_phases_down = np.array(node_mean_phases_down)
    node_mean_phases_up = np.array(node_mean_phases_up)

    if np.any(node_mean_phases_up < 0) or np.any(node_mean_phases_down > 0):
        print("Modulo was necessary")
    node_mean_phases_up = np.mod(node_mean_phases_up, np.pi)
    node_mean_phases_down = np.mod(node_mean_phases_down, -np.pi)

    # time (s) each row of `states` spends in its active state, as a fraction
    # of the run (dt and duration are in ms); equals the former
    # 1 - (down time in percent) / 100
    len_states = np.sum(states, axis=1) * model.params.dt / 1000
    fraction_up = len_states / (model.params.duration / 1000)

    # ----------------- Correlations ----------------- #
    columns = ['mean_degree', 'degree', 'closeness', 'betweenness',
               'mean_shortest_path', 'neighbor_degree', 'neighbor_degree_new',
               'clustering_coefficient', 'omega', 'sigma',
               'mean_clustering_coefficient', 'backbone', 'Cmat', 'Dmat']
    subset_cols = ['degree', 'closeness', 'betweenness', 'neighbor_degree_new',
                   'clustering_coefficient']

    def _graph_frame(matrix, measure_cols, new_names):
        """Build a graph from `matrix`; return `measure_cols` renamed to `new_names`."""
        _, gm = graph_measures(make_graph(matrix))
        df = pd.DataFrame(columns=columns)
        df.loc[0] = gm
        frame = pd.DataFrame(df.loc[0, measure_cols].to_dict())
        frame.columns = new_names
        return frame

    frames = [_graph_frame(cmat, subset_cols,
                           ['degree_sc', 'closeness_sc', 'betweenness_sc',
                            'neighbor_sc', 'clustering_sc'])]

    # functional graph: negative correlations are zeroed before building it
    fc = bold_fc.copy()
    fc[fc < 0] = 0
    frames.append(_graph_frame(fc, subset_cols,
                               ['degree_fc', 'closeness_fc', 'betweenness_fc',
                                'neighbor_fc', 'clustering_fc']))

    for i, freq in enumerate(all_freq_bands):
        frames.append(_graph_frame(con[:, :, i], ['degree'],
                                   [f'degree_plv_{freq}']))

    frames.append(_graph_frame(dtw_norm, ['degree'], ['degree_dtw']))

    if not isinstance(fc_real, bool):
        frames.append(_graph_frame(fc_real, subset_cols,
                                   ['degree_fc_real', 'closeness_fc_real',
                                    'betweenness_fc_real', 'neighbor_fc_real',
                                    'clustering_fc_real']))

    results = pd.concat(frames, axis=1)
    results.loc[:, 'time_up'] = fraction_up
    results.loc[:, 'phases_up'] = node_mean_phases_up
    results.loc[:, 'phases_down'] = node_mean_phases_down

    corr = results.corr().fillna(0)
    cmap = sns.diverging_palette(250, 10, as_cmap=True)
    with sns.axes_style("white"):
        _, ax = plt.subplots(figsize=(20, 18))
        sns.heatmap(corr, cmap=cmap, vmax=1., vmin=-1.,
                    square=True, annot=True, ax=ax)
    plt.tight_layout()
    plt.savefig(path / f"{fname}_correlations.png")
    plt.close()

    sns.clustermap(corr, cmap=cmap, vmax=1., vmin=-1.)
    plt.tight_layout()
    # filename kept unchanged (including its spelling) for existing outputs
    plt.savefig(path / f"{fname}_correlations_clustes.png")
    plt.close()
    return corr