예제 #1
0
def calc_baseline_connectivity(ictals_clips, connectivity_template):
    """Concatenate the connectivity of all baseline clips along axis 1.

    For every file name in ``ictals_clips['baseline']`` the matching
    connectivity file is loaded, its ``con_values`` and ``con_values2`` arrays
    are passed through ``connectivity.find_best_ord`` and the results are
    joined along axis 1.

    Args:
        ictals_clips: Mapping whose ``'baseline'`` entry holds the baseline
            clip file names.
        connectivity_template: Format string with a ``{clip_name}`` field that
            yields the connectivity file name of a clip.

    Returns:
        A ``(baseline_mat1, baseline_mat2)`` tuple of arrays, or
        ``(None, None)`` when there are no baseline clips.
    """
    mats1, mats2 = [], []
    for clip_fname in ictals_clips['baseline']:
        clip_name = utils.namebase(clip_fname)
        con_ictal_fname = connectivity_template.format(clip_name=clip_name)
        print('concatenate {}'.format(clip_name))
        # Close the archive as soon as both arrays have been read.
        # NOTE(review): the original comment suggests d['con_values'] is
        # C x W x O before find_best_ord -- confirm.
        with np.load(con_ictal_fname) as d:
            mats1.append(connectivity.find_best_ord(d['con_values'], False))
            mats2.append(connectivity.find_best_ord(d['con_values2'], False))

    if not mats1:
        return None, None
    if len(mats1) == 1:
        # A single clip needs no concatenation; hand back independent copies.
        return mats1[0].copy(), mats2[0].copy()
    # One concatenation at the end instead of re-copying the growing matrix
    # on every iteration.
    return np.concatenate(mats1, axis=1), np.concatenate(mats2, axis=1)
예제 #2
0
파일: utils.py 프로젝트: bdthombre/mmvt
def norm_values(baseline_x,
                cond_x,
                divide_by_baseline_std,
                threshold,
                reduce_to_3d=False):
    """Normalize condition values against baseline statistics.

    The baseline mean along axis 1 is subtracted from ``cond_x``; optionally
    the result is also divided by the baseline standard deviation along
    axis 1. The input arrays are not modified.

    Args:
        baseline_x: Baseline array; its mean/std are taken along axis 1.
            (Original note: "Nodes x Time x Orders" -- assumed layout, not
            enforced here.)
        cond_x: Condition array, broadcast-compatible with the baseline
            statistics.
        divide_by_baseline_std: If true, divide the mean-corrected values by
            the baseline standard deviation.
        threshold: If > 0, every first-axis entry of ``cond_x`` whose maximum
            absolute (un-normalized) value along axis 1 is below the threshold
            is set to zero after normalization. For arrays with more than two
            dimensions an entry is zeroed as soon as any of its trailing
            slices is below the threshold.
        reduce_to_3d: If true and the normalized array has 4 dimensions, it is
            passed through ``connectivity.find_best_ord``.

    Returns:
        The normalized array. Its min/max (ignoring NaNs) are printed.
    """
    baseline_mean = baseline_x.mean(axis=1, keepdims=True)

    # The mask must be decided on the raw condition values, before they are
    # shifted/scaled below.
    if threshold > 0:
        mask_indices = np.where(np.max(np.abs(cond_x), axis=1) < threshold)

    if divide_by_baseline_std:
        # The scale comes from the baseline, not from the condition itself.
        baseline_std = baseline_x.std(axis=1, keepdims=True)
        cond_x = (cond_x - baseline_mean) / baseline_std
    else:
        cond_x = cond_x - baseline_mean

    if threshold > 0:
        cond_x[mask_indices[0]] = 0

    if reduce_to_3d and cond_x.ndim == 4:
        from src.preproc import connectivity
        cond_x = connectivity.find_best_ord(cond_x)
    print('{:.4f} {:.4f}'.format(np.nanmin(cond_x), np.nanmax(cond_x)))
    return cond_x
예제 #3
0
def normalize_connectivity(subject,
                           ictals_clips,
                           modality,
                           divide_by_baseline_std,
                           threshold,
                           reduce_to_3d,
                           overwrite=False,
                           n_jobs=6):
    """Normalize each clip's 'gc' connectivity against its own baseline file.

    For every clip the connectivity file and its ``<clip>_baseline``
    counterpart are loaded, ``con_values`` (and ``con_values2`` when present)
    are normalized with ``epi_utils.norm_values`` and the result is saved next
    to the input as ``<input name without its last 4 characters>_zvals.npz``.
    Clips with a missing input file are reported and skipped.

    Args:
        subject: Subject name, forwarded to ``connectivity.get_output_fname``.
        ictals_clips: Iterable of clip file names.
        modality: Modality, forwarded to ``connectivity.get_output_fname``.
        divide_by_baseline_std: Forwarded to ``epi_utils.norm_values``.
        threshold: Forwarded to ``epi_utils.norm_values``.
        reduce_to_3d: If true, the arrays are passed through
            ``connectivity.find_best_ord`` before normalization.
        overwrite: Currently unused; existing outputs are always rewritten.
        n_jobs: Currently unused.
    """
    connectivity_template = connectivity.get_output_fname(
        subject, 'gc', modality, 'mean_flip', 'all_{}_func_rois')
    for clip_fname in ictals_clips:
        clip_name = utils.namebase(clip_fname)
        con_ictal_fname = connectivity_template.format(clip_name)
        con_baseline_fname = connectivity_template.format(
            '{}_baseline'.format(clip_name))
        # Drop the 4-character extension before adding the suffix.
        output_fname = '{}_zvals.npz'.format(con_ictal_fname[:-4])

        missing = [
            f for f in (con_ictal_fname, con_baseline_fname)
            if not op.isfile(f)
        ]
        if missing:
            for fname in missing:
                print('{} is missing!'.format(fname))
            continue

        print('normalize_connectivity: {}:'.format(clip_name))
        # NOTE(review): allow_pickle=True executes pickled objects on load;
        # only use with connectivity files produced by this pipeline.
        d_ictal = utils.Bag(np.load(con_ictal_fname, allow_pickle=True))
        d_baseline = utils.Bag(np.load(con_baseline_fname, allow_pickle=True))
        # con_values2 is optional; it is needed in both files to be usable.
        has_con_values2 = ('con_values2' in d_baseline
                           and 'con_values2' in d_ictal)

        if reduce_to_3d:
            d_ictal.con_values = connectivity.find_best_ord(
                d_ictal.con_values, False)
            d_baseline.con_values = connectivity.find_best_ord(
                d_baseline.con_values, False)
            if has_con_values2:
                d_ictal.con_values2 = connectivity.find_best_ord(
                    d_ictal.con_values2, False)
                d_baseline.con_values2 = connectivity.find_best_ord(
                    d_baseline.con_values2, False)

        # NOTE(review): norm_values is always called with reduce_to_3d=True,
        # independent of this function's reduce_to_3d -- confirm intended.
        d_ictal.con_values = epi_utils.norm_values(d_baseline.con_values,
                                                   d_ictal.con_values,
                                                   divide_by_baseline_std,
                                                   threshold, True)
        if has_con_values2:
            d_ictal.con_values2 = epi_utils.norm_values(
                d_baseline.con_values2, d_ictal.con_values2,
                divide_by_baseline_std, threshold, True)
        print('Saving norm connectivity in {}'.format(output_fname))
        np.savez(output_fname, **d_ictal)
예제 #4
0
파일: plots.py 프로젝트: keshava/mmvt
def calc_cond_and_basline(subject,
                          con_method,
                          modality,
                          condition,
                          extract_mode,
                          band_name,
                          con_indentifer,
                          use_zvals,
                          node_names,
                          nodes_names_includes_hemi=False,
                          use_abs=True,
                          threshold=0.7,
                          window_length=25,
                          stc_downsample=2,
                          cond_name='interictals',
                          stc_subfolder='zvals',
                          stc_name=''):
    """Load condition/baseline connectivity and keep the filtered connections.

    Both connectivity files are located with ``get_cond_and_baseline_fnames``.
    The two direction arrays of each file (as returned by ``fix_con_values``)
    are passed through ``connectivity.find_best_ord`` and then restricted to
    the connections selected by ``epi_utils.filter_connections`` on the
    condition data. If the source-estimate file exists, its per-time maximum
    over axis 0 and its time axis are downsampled and returned as well.

    Returns:
        ``(d_cond, d_baseline, x_cond, x_baseline, names, stc_data, times)``.
        All seven items are ``None`` when an input file is missing or no
        connection passes the filter. ``stc_data`` and ``times`` are ``None``
        when the source-estimate file does not exist. ``window_length`` is not
        used.
    """
    import mne
    from src.preproc import connectivity

    nothing = (None,) * 7

    input_fname, baseline_fname = get_cond_and_baseline_fnames(
        subject, con_method, modality, condition, extract_mode, band_name,
        con_indentifer, use_zvals, cond_name)
    if not (op.isfile(input_fname) and op.isfile(baseline_fname)):
        return nothing

    # Optional source-estimate trace.
    if stc_name == '':
        stc_name = '{}-epilepsy-dSPM-meg-{}-average-amplitude-zvals-rh.stc'.format(
            subject, condition)
    stc_fname = op.join(MMVT_DIR, subject, 'meg', stc_subfolder, stc_name)
    stc_data, times = None, None
    if op.isfile(stc_fname):
        stc = mne.read_source_estimate(stc_fname)
        times = utils.downsample(stc.times, stc_downsample)
        stc_data = utils.downsample(np.max(stc.data, axis=0), stc_downsample)

    d_cond = np.load(input_fname)
    d_baseline = np.load(baseline_fname)

    # Condition: keep the orders returned alongside the reduced values.
    cond_raw1, cond_raw2 = fix_con_values(d_cond)
    con_values1, best_ords1 = connectivity.find_best_ord(
        cond_raw1, return_ords=True)
    con_values2, best_ords2 = connectivity.find_best_ord(
        cond_raw2, return_ords=True)

    # Baseline: only the reduced values are needed.
    base_raw1, base_raw2 = fix_con_values(d_baseline)
    baseline_values1 = connectivity.find_best_ord(base_raw1, return_ords=False)
    baseline_values2 = connectivity.find_best_ord(base_raw2, return_ords=False)

    def select(values, con_names):
        # Same filter settings for both directions.
        return epi_utils.filter_connections(
            values,
            con_names,
            threshold,
            node_names,
            '',
            use_abs,
            nodes_names_includes_hemi=nodes_names_includes_hemi)

    con_names1 = d_cond['con_names']
    mask1 = select(con_values1, con_names1)
    con_names2 = d_cond['con_names2']
    mask2 = select(con_values2, con_names2)

    names = np.concatenate((con_names1[mask1], con_names2[mask2]))
    if not len(names):
        print('{} no connections'.format(condition))
        return nothing

    x_cond = np.concatenate((con_values1[mask1], con_values2[mask2]))
    x_baseline = np.concatenate(
        (baseline_values1[mask1], baseline_values2[mask2]))

    have_ords = not (best_ords1 is None or best_ords2 is None)
    if have_ords:
        # Append the integer order of each kept connection to its name.
        kept_ords = np.concatenate((best_ords1[mask1], best_ords2[mask2]))
        names = [
            '{} {}'.format(con_name, int(order))
            for con_name, order in zip(names, kept_ords)
        ]
    return d_cond, d_baseline, x_cond, x_baseline, names, stc_data, times
예제 #5
0
def normalize_connectivity(subject, ictals_clips, modality, atlas, divide_by_baseline_std, threshold,
                           reduce_to_3d, time_axis=None, overwrite=False, n_jobs=6):
    """Find, per time point, the ictal connections that stand out from baseline.

    The baseline connectivity is loaded from ``<modality>_baseline_<atlas>_gc.npz``
    or, if that file is missing or ``overwrite`` is true, computed with
    ``calc_baseline_connectivity`` and saved there. For every clip in
    ``ictals_clips['ictal']`` and every time point, the weaker of the two
    directions is zeroed, the values are compared with the baseline using the
    hard-coded ``calc_method`` and the selected connections are saved to
    ``<modality>_<clip>_<atlas>_sig_con.pkl`` and passed to
    ``plots.plot_pvalues``.

    Args:
        subject: Subject name (folder under ``MMVT_DIR``).
        ictals_clips: Mapping with ``'ictal'`` (and, for the baseline
            computation, ``'baseline'``) lists of clip file names.
        modality: Modality prefix of the connectivity file names.
        atlas: Atlas name used in the connectivity file names.
        divide_by_baseline_std: Currently unused.
        threshold: Currently unused.
        reduce_to_3d: Currently unused.
        time_axis: Forwarded to ``plots.plot_pvalues``.
        overwrite: Recompute the baseline connectivity even if its file exists.
            Per-clip results are always recomputed.
        n_jobs: Currently unused.

    Raises:
        ValueError: If the internal ``calc_method`` is not a known method.
    """
    # https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.stats.ttest_1samp.html
    import scipy.stats

    calc_method = 'baseline_correction'
    top_k = 5
    include = None  # e.g. ('superiorfrontal', 'parstriangularis', 'rostralmiddlefrontal')
    if calc_method not in ('ttest_1samp', 'zvals', 'baseline_correction'):
        raise ValueError('Unknown calc_method: {}'.format(calc_method))

    baseline_con_fname = op.join(MMVT_DIR, subject, 'connectivity', '{}_baseline_{}_gc.npz'.format(modality, atlas))
    connectivity_template = op.join(MMVT_DIR, subject, 'connectivity', '{}_all_{}_{}_gc.npz'.format(
        modality, '{clip_name}', atlas))
    figures_fol = utils.make_dir(op.join(MMVT_DIR, subject, 'figures', 'gc'))
    if not op.isfile(baseline_con_fname) or overwrite:
        baseline_con_values1, baseline_con_values2 = calc_baseline_connectivity(ictals_clips, connectivity_template)
        print('Saving baseline connectivity {}'.format(baseline_con_fname))
        np.savez(baseline_con_fname, con_values=baseline_con_values1, con_values2=baseline_con_values2)
    else:
        print('Loading baseline connectivity {}'.format(baseline_con_fname))
        d_baseline = np.load(baseline_con_fname)
        baseline_con_values1, baseline_con_values2 = d_baseline['con_values'], d_baseline['con_values2']

    # Baseline statistics do not depend on the clip or the time point.
    if calc_method in ('zvals', 'baseline_correction'):
        baseline_mean1 = baseline_con_values1.mean(1)
        baseline_mean2 = baseline_con_values2.mean(1)
    if calc_method == 'zvals':
        baseline_std1 = baseline_con_values1.std(1)
        baseline_std2 = baseline_con_values2.std(1)

    def selected_inds(res):
        # Indices of the connections to keep for one time point.
        if include is not None:
            return np.where(res > 0)[0]
        if len(res) == 0:
            return np.array([], dtype=int)
        # Clamp k so that fewer than top_k connections do not raise.
        # NOTE(review): the strict '>' keeps at most top_k - 1 connections;
        # kept as in the original -- confirm whether '>=' was intended.
        k = min(top_k, len(res))
        return np.where(res > sorted(res)[-k])[0]

    for clip_fname in ictals_clips['ictal']:
        clip_name = utils.namebase(clip_fname)
        print('\n\nAnalyzing {}'.format(clip_name))
        output_fname = op.join(MMVT_DIR, subject, 'connectivity', '{}_{}_{}_sig_con.pkl'.format(
            modality, clip_name, atlas))
        # Cached results are deliberately not reused (the original check on
        # output_fname/overwrite was disabled); always recompute.
        con_ictal_fname = connectivity_template.format(clip_name=clip_name)
        d_ictal = utils.Bag(np.load(con_ictal_fname, allow_pickle=True))
        con_values1 = connectivity.find_best_ord(d_ictal.con_values, False)
        con_values2 = connectivity.find_best_ord(d_ictal.con_values2, False)
        con_names1 = d_ictal['con_names']
        # The second direction has its own names when the file provides them.
        con_names2 = d_ictal['con_names2'] if 'con_names2' in d_ictal else con_names1

        _, T = con_values1.shape
        sig_con1, sig_con2, names1, names2 = [], [], [], []
        for t in range(T):
            # Decide both masks on the untouched values, then zero the weaker
            # direction; otherwise the second comparison would see the zeros
            # written by the first one.
            weaker1 = con_values1[:, t] < con_values2[:, t]
            weaker2 = con_values2[:, t] < con_values1[:, t]
            con_values1[weaker1, t] = 0
            con_values2[weaker2, t] = 0

            if calc_method == 'ttest_1samp':
                res1 = scipy.stats.ttest_1samp(baseline_con_values1, con_values1[:, t], axis=1)[0]
                res2 = scipy.stats.ttest_1samp(baseline_con_values2, con_values2[:, t], axis=1)[0]
            elif calc_method == 'zvals':
                res1 = (con_values1[:, t] - baseline_mean1) / baseline_std1
                res2 = (con_values2[:, t] - baseline_mean2) / baseline_std2
            else:  # 'baseline_correction'
                res1 = con_values1[:, t] - baseline_mean1
                res2 = con_values2[:, t] - baseline_mean2

            mask1, mask2 = selected_inds(res1), selected_inds(res2)
            sig_con1.append(res1[mask1])
            sig_con2.append(res2[mask2])
            names1.append(con_names1[mask1])
            names2.append(con_names2[mask2])
        print('Saving results in {}'.format(output_fname))
        utils.save((sig_con1, sig_con2, names1, names2), output_fname)
        plots.plot_pvalues(clip_name, time_axis, sig_con1, sig_con2, names1, names2, include, figures_fol)
예제 #6
0
def plot_norm_data(d_cond,
                   d_baseline,
                   x_axis,
                   condition,
                   threshold,
                   node_name,
                   stc_data,
                   stc_times,
                   windows_len=100,
                   windows_shift=10,
                   ax=None):
    """Scatter the baseline-corrected connectivity per connection type.

    The baseline mean (along axis 1) is subtracted from both direction arrays
    of ``d_cond``, the result is passed through ``connectivity.find_best_ord``
    and, for every combination of 'within'/'between' and ``utils.HEMIS``, the
    per-window maximum over the connections selected by
    ``epi_utils.filter_connections`` is drawn on ``ax``. If ``stc_data`` is
    given, it is drawn on a twin y-axis.

    Args:
        d_cond: Mapping with ``con_values``, ``con_values2``, ``con_names``
            and ``con_names2`` of the condition.
        d_baseline: Mapping with ``con_values`` and ``con_values2`` of the
            baseline.
        x_axis: Currently unused.
        condition: ``'R'`` titles the plot "Right", anything else "Left".
        threshold: Forwarded to ``epi_utils.filter_connections``.
        node_name: Forwarded to ``epi_utils.filter_connections``.
        stc_data: Source trace to overlay, or ``None``.
        stc_times: Time axis; indexed at ``windows_len`` and ``-1`` to place
            the windows, so it must not be ``None`` when connections are found.
        windows_len: Index into ``stc_times`` where the first window is drawn.
        windows_shift: Currently unused.
        ax: Axes to draw on. If ``None``, a new figure is created and shown.
    """
    import matplotlib.pyplot as plt
    from itertools import product
    from src.preproc import connectivity

    norm1 = d_cond['con_values'] - d_baseline['con_values'].mean(
        1, keepdims=True)
    norm2 = d_cond['con_values2'] - d_baseline['con_values2'].mean(
        1, keepdims=True)
    norm1, best_ords1 = connectivity.find_best_ord(norm1, return_ords=True)
    norm2, best_ords2 = connectivity.find_best_ord(norm2, return_ords=True)
    norm = {}

    # Remember whether we own the figure: ax is rebound below, so testing it
    # again at the end would never be true.
    show_when_done = ax is None
    if show_when_done:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    conn_conditions = list(product(['within', 'between'], utils.HEMIS))
    colors = ['c', 'b', 'k', 'm']
    lines, labels = [], []
    for conn_type, color in zip(conn_conditions, colors):
        mask1 = epi_utils.filter_connections(node_name,
                                             norm1,
                                             d_cond['con_names'],
                                             threshold,
                                             conn_type,
                                             use_abs=False)
        mask2 = epi_utils.filter_connections(node_name,
                                             norm2,
                                             d_cond['con_names2'],
                                             threshold,
                                             conn_type,
                                             use_abs=False)
        norm[conn_type] = np.concatenate((norm1[mask1], norm2[mask2]))
        names = np.concatenate(
            (d_cond['con_names'][mask1], d_cond['con_names2'][mask2]))
        if best_ords1 is not None and best_ords2 is not None:
            best_ords = np.concatenate((best_ords1[mask1], best_ords2[mask2]))
            names = [
                '{} {}'.format(name, int(best_ord))
                for name, best_ord in zip(names, best_ords)
            ]
        if len(names) == 0 or max(norm[conn_type].max(0)) < 0:
            print('{} no connections {}'.format(condition, conn_type))
            continue

        windows_max = norm[conn_type].max(0)
        windows_num = norm[conn_type].shape[1]
        dt = (stc_times[-1] - stc_times[windows_len]) / windows_num
        print(
            'windows num: {} windows length: {:.2f}ms windows shift: {:2f}ms'
            .format(windows_num,
                    (stc_times[windows_len] - stc_times[0]) * 1000,
                    dt * 1000))
        # linspace guarantees exactly windows_num points; arange with a float
        # step may yield one extra point and break the scatter call.
        time = np.linspace(stc_times[windows_len], stc_times[-1],
                           windows_num, endpoint=False)
        lines.append(ax.scatter(time, windows_max, color=color))

        side = 'right' if conn_type[1] == 'rh' else 'left'
        if conn_type[0] == 'within':
            labels.append(' '.join((conn_type[0], side)))
        else:
            labels.append('{} to {}'.format(conn_type[0], side))

    if stc_data is not None:
        ax2 = ax.twinx()
        l = ax2.plot(stc_times[windows_len:], stc_data[windows_len:].T,
                     'y--')
        lines.append(l[0])
        labels.append('Source normalized activity')
        ax2.set_ylim([0.5, 4.5])
        ax2.set_yticks(range(1, 5))
        ax2.set_ylabel('Source z-values', fontsize=12)
    ax.set_ylabel('Causality: Interictals\n minus Baseline', fontsize=12)
    ax.set_ylim([0, 0.7])
    # Title the axes we draw on; plt.title would target whatever axes is
    # current (e.g. the twin axes or another subplot of the caller).
    ax.set_title('{} interictals cluster'.format(
        'Right' if condition == 'R' else 'Left'))

    ax.legend(lines, labels, loc=0)
    if show_when_done:
        plt.show()