Beispiel #1
0
def plot_norm_data(x_cond,
                   x_baseline,
                   con_names,
                   condition,
                   threshold,
                   nodes_names,
                   stc_data,
                   stc_times,
                   windows_len=100,
                   windows_shift=10,
                   figures_fol='',
                   ax=None,
                   nodes_names_includes_hemi=False):
    """Scatter baseline-subtracted connectivity for the strongest connections.

    The connectivity is normalized as ``x_cond - x_baseline``. For every
    combination of ('within', 'between') and ``utils.HEMIS`` the connections
    are selected with ``epi_utils.filter_connections``; a selected connection
    is drawn only if its normalized value exceeds 2 somewhere between the
    first time stamp later than -0.1 and the first time stamp later than 1.
    For every drawn connection a tuple ``(peak time, label, peak value, name)``
    is collected; the sorted list is printed and stored with ``utils.save``
    next to the figure (same file name, 'pkl' extension).

    Args:
        x_cond: 2-D array indexed as (connection, window); its second
            dimension defines the number of time windows.
        x_baseline: array subtracted from ``x_cond`` (must broadcast).
        con_names: connection names, one per row of ``x_cond``. Only the part
            before the first space is passed to the connection filter.
        condition: label used in printed messages, the title and file names.
        threshold: forwarded to ``epi_utils.filter_connections``.
        nodes_names: forwarded to ``epi_utils.filter_connections``.
        stc_data: optional source data drawn on a twin y-axis, or None.
        stc_times: time stamps; ``stc_times[windows_len]`` and
            ``stc_times[-1]`` delimit the connectivity time axis.
        windows_len: index into ``stc_times`` where the time axis starts.
        windows_shift: unused; kept for interface compatibility.
        figures_fol: output root folder. When empty the figure is shown
            instead of being saved (the connections file is still written).
        ax: optional matplotlib axes to draw on; a new figure is created
            when None.
        nodes_names_includes_hemi: forwarded to
            ``epi_utils.filter_connections``.

    Raises:
        IndexError: if no time stamp is later than -0.1.
    """
    import os

    windows_num = x_cond.shape[1]
    # One time stamp per connectivity window. np.linspace always yields
    # exactly windows_num samples, whereas np.arange with a fractional step
    # may yield one sample more or less depending on rounding, which breaks
    # the scatter call below (x and y must have the same size).
    time = np.linspace(stc_times[windows_len],
                       stc_times[-1],
                       windows_num,
                       endpoint=False)
    t0 = np.where(time > -0.1)[0][0]
    later_than_1 = np.where(time > 1)[0]
    # Recordings that end before 1 are searched up to their last window.
    t1 = later_than_1[0] if later_than_1.size > 0 else len(time)

    con_norm = x_cond - x_baseline
    fig_fname = op.join(figures_fol, 'ictal-baseline',
                        '{}-connectivity-ictal-baseline.jpg'.format(condition))
    connection_fname = utils.change_fname_extension(fig_fname, 'pkl')

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    conn_conditions = list(product(['within', 'between'], utils.HEMIS))
    colors = ['c', 'b', 'k', 'm']
    lines, labels = [], []
    no_ord_con_names = [con_name.split(' ')[0] for con_name in con_names]

    connections = []
    for conn_type, color in zip(conn_conditions, colors):
        mask = epi_utils.filter_connections(
            con_norm,
            no_ord_con_names,
            threshold,
            nodes_names,
            conn_type,
            use_abs=False,
            nodes_names_includes_hemi=nodes_names_includes_hemi)
        selected_num = np.sum(mask)
        if selected_num == 0:
            print('{} no connections {}'.format(condition, conn_type))
            continue
        print('{}: {} connection for {} {}'.format(condition, selected_num,
                                                   conn_type[0],
                                                   conn_type[1]))
        names = np.array(con_names)[mask]
        selected = con_norm[mask]
        label_title = ' '.join(
            conn_type) if conn_type[0] == 'within' else '{} to {}'.format(
                *conn_type)
        needs_legend_entry = True
        for name, values in zip(names, selected):
            window = values[t0:t1]
            peak = window.max()
            # Only connections whose peak in the window exceeds 2 are kept.
            if peak > 2:
                peak_time = time[window.argmax() + t0]
                connections.append((peak_time, label_title, peak, name))
                points = ax.scatter(time, values, color=color)
                # One legend entry per connection type is enough.
                if needs_legend_entry:
                    lines.append(points)
                    labels.append(label_title)
                    needs_legend_entry = False

    connections = sorted(connections)
    for con in connections:
        print(con)
    # The 'ictal-baseline' sub-folder is not created anywhere else in this
    # function; make sure it exists before writing into it.
    os.makedirs(op.dirname(fig_fname), exist_ok=True)
    utils.save(connections, connection_fname)
    if stc_data is not None:
        ax2 = ax.twinx()
        stc_lines = ax2.plot(stc_times[windows_len:],
                             stc_data[windows_len:].T,
                             'y--',
                             alpha=0.2)
        lines.append(stc_lines[0])
        labels.append('Source normalized activity')
        ax2.set_ylabel('Source z-values', fontsize=12)
    ax.set_ylabel('Causality: Interictals\n minus Baseline', fontsize=12)
    ax.set_ylim(bottom=0)
    # Use the axes/figure objects explicitly: after twinx() (or with a
    # caller-supplied ax) pyplot's "current" axes is not necessarily ax.
    ax.set_title('{} ictal-baseline ({} connections)'.format(
        condition, x_cond.shape[0]))
    ax.legend(lines, labels, loc='upper right')
    ax.axvline(x=0, linestyle='--', color='k')
    if figures_fol != '':
        figure = ax.figure
        figure.savefig(fig_fname, dpi=300)
        print('Figure was saved in {}'.format(fig_fname))
        plt.close(figure)
    else:
        plt.show()
Beispiel #2
0
def calc_cond_and_basline(subject,
                          con_method,
                          modality,
                          condition,
                          extract_mode,
                          band_name,
                          con_indentifer,
                          use_zvals,
                          node_names,
                          nodes_names_includes_hemi=False,
                          use_abs=True,
                          threshold=0.7,
                          window_length=25,
                          stc_downsample=2,
                          cond_name='interictals',
                          stc_subfolder='zvals',
                          stc_name=''):
    """Load condition and baseline connectivity and keep the filtered rows.

    The two connectivity files are located with
    ``get_cond_and_baseline_fnames``. Both value sets of each file (as
    returned by ``fix_con_values``) go through
    ``connectivity.find_best_ord``; the masks computed on the condition
    values with ``epi_utils.filter_connections`` are applied to both the
    condition and the baseline values. If the STC file exists, the maximum
    of its data over axis 0 and its time stamps are downsampled by
    ``stc_downsample``.

    ``window_length`` is not used by this function.

    Returns:
        A 7-tuple ``(d_cond, d_baseline, x_cond, x_baseline, names,
        stc_data, times)``. ``stc_data`` and ``times`` are None when the STC
        file is missing. All seven items are None when either connectivity
        file is missing or when no connection passes the filter.
    """
    import mne
    from src.preproc import connectivity

    nothing = (None, ) * 7

    def select(values, con_names):
        # Same filter settings for both value sets of the condition file.
        return epi_utils.filter_connections(
            values,
            con_names,
            threshold,
            node_names,
            '',
            use_abs,
            nodes_names_includes_hemi=nodes_names_includes_hemi)

    input_fname, baseline_fname = get_cond_and_baseline_fnames(
        subject, con_method, modality, condition, extract_mode, band_name,
        con_indentifer, use_zvals, cond_name)
    if not (op.isfile(input_fname) and op.isfile(baseline_fname)):
        return nothing

    # Optional source estimate that accompanies the connectivity data.
    if stc_name == '':
        stc_name = '{}-epilepsy-dSPM-meg-{}-average-amplitude-zvals-rh.stc'.format(
            subject, condition)
    stc_fname = op.join(op.join(MMVT_DIR, subject, 'meg', stc_subfolder),
                        stc_name)
    stc_data, times = None, None
    if op.isfile(stc_fname):
        stc = mne.read_source_estimate(stc_fname)
        times = utils.downsample(stc.times, stc_downsample)
        stc_data = utils.downsample(np.max(stc.data, axis=0), stc_downsample)

    d_cond = np.load(input_fname)
    d_baseline = np.load(baseline_fname)

    cond_vals1, cond_vals2 = fix_con_values(d_cond)
    cond_vals1, ords1 = connectivity.find_best_ord(cond_vals1,
                                                   return_ords=True)
    cond_vals2, ords2 = connectivity.find_best_ord(cond_vals2,
                                                   return_ords=True)
    base_vals1, base_vals2 = fix_con_values(d_baseline)
    base_vals1 = connectivity.find_best_ord(base_vals1, return_ords=False)
    base_vals2 = connectivity.find_best_ord(base_vals2, return_ords=False)

    keep1 = select(cond_vals1, d_cond['con_names'])
    keep2 = select(cond_vals2, d_cond['con_names2'])

    names = np.concatenate(
        (d_cond['con_names'][keep1], d_cond['con_names2'][keep2]))
    if len(names) == 0:
        print('{} no connections'.format(condition))
        return nothing

    x_cond = np.concatenate((cond_vals1[keep1], cond_vals2[keep2]))
    x_baseline = np.concatenate((base_vals1[keep1], base_vals2[keep2]))
    if ords1 is not None and ords2 is not None:
        # Append the selected order to every connection name.
        kept_ords = np.concatenate((ords1[keep1], ords2[keep2]))
        names = [
            '{} {}'.format(con_name, int(con_ord))
            for con_name, con_ord in zip(names, kept_ords)
        ]
    return d_cond, d_baseline, x_cond, x_baseline, names, stc_data, times
Beispiel #3
0
def plot_norm_data(d_cond,
                   d_baseline,
                   x_axis,
                   condition,
                   threshold,
                   node_name,
                   stc_data,
                   stc_times,
                   windows_len=100,
                   windows_shift=10,
                   ax=None):
    """Scatter the per-window maximum of baseline-normalized connectivity.

    Both value sets ('con_values', 'con_values2') of ``d_cond`` are
    normalized by subtracting the per-row mean of the matching baseline
    values and passed through ``connectivity.find_best_ord``. For every
    combination of ('within', 'between') and ``utils.HEMIS`` the connections
    selected by ``epi_utils.filter_connections`` are merged and their maximum
    over connections is drawn for every window. Connection types with no
    selected connection, or whose overall maximum is negative, are reported
    and skipped.

    Args:
        d_cond: mapping with 'con_values', 'con_values2', 'con_names' and
            'con_names2' entries for the condition.
        d_baseline: mapping with 'con_values' and 'con_values2' entries for
            the baseline.
        x_axis: unused; kept for interface compatibility.
        condition: 'R' titles the plot 'Right', anything else 'Left'; also
            used in printed messages.
        threshold: forwarded to ``epi_utils.filter_connections``.
        node_name: forwarded to ``epi_utils.filter_connections``.
        stc_data: optional source data drawn on a twin y-axis, or None.
        stc_times: time stamps; ``stc_times[windows_len]`` and
            ``stc_times[-1]`` delimit the connectivity time axis.
        windows_len: index into ``stc_times`` where the time axis starts.
        windows_shift: unused; kept for interface compatibility.
        ax: optional matplotlib axes to draw on. When None a new figure is
            created and shown at the end.
    """
    import matplotlib.pyplot as plt
    from itertools import product
    from src.preproc import connectivity

    norm1 = d_cond['con_values'] - d_baseline['con_values'].mean(
        1, keepdims=True)
    norm2 = d_cond['con_values2'] - d_baseline['con_values2'].mean(
        1, keepdims=True)
    norm1, _ = connectivity.find_best_ord(norm1, return_ords=True)
    norm2, _ = connectivity.find_best_ord(norm2, return_ords=True)

    # Remember whether the figure is ours before ax is replaced; testing
    # ``ax is None`` after the assignment below could never succeed.
    show_when_done = ax is None
    if show_when_done:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    conn_conditions = list(product(['within', 'between'], utils.HEMIS))
    colors = ['c', 'b', 'k', 'm']
    lines, labels = [], []
    for conn_type, color in zip(conn_conditions, colors):
        mask1 = epi_utils.filter_connections(node_name,
                                             norm1,
                                             d_cond['con_names'],
                                             threshold,
                                             conn_type,
                                             use_abs=False)
        mask2 = epi_utils.filter_connections(node_name,
                                             norm2,
                                             d_cond['con_names2'],
                                             threshold,
                                             conn_type,
                                             use_abs=False)
        selected = np.concatenate((norm1[mask1], norm2[mask2]))
        # The emptiness test must come first: max() of an empty array raises.
        if selected.shape[0] == 0 or selected.max() < 0:
            print('{} no connections {}'.format(condition, conn_type))
            continue

        windows_num = selected.shape[1]
        dt = (stc_times[-1] - stc_times[windows_len]) / windows_num
        print('windows num: {} windows length: {:.2f}ms windows shift: {:.2f}ms'
              .format(windows_num,
                      (stc_times[windows_len] - stc_times[0]) * 1000,
                      dt * 1000))
        # np.linspace always yields exactly windows_num samples, whereas
        # np.arange with a fractional step may be off by one, which breaks
        # the scatter call (x and y must have the same size).
        time = np.linspace(stc_times[windows_len],
                           stc_times[-1],
                           windows_num,
                           endpoint=False)
        points = ax.scatter(time, selected.max(0), color=color)
        lines.append(points)
        hemi = 'right' if conn_type[1] == 'rh' else 'left'
        if conn_type[0] == 'within':
            labels.append(' '.join((conn_type[0], hemi)))
        else:
            labels.append('{} to {}'.format(conn_type[0], hemi))

    if stc_data is not None:
        ax2 = ax.twinx()
        stc_lines = ax2.plot(stc_times[windows_len:],
                             stc_data[windows_len:].T, 'y--')
        lines.append(stc_lines[0])
        labels.append('Source normalized activity')
        ax2.set_ylim([0.5, 4.5])
        ax2.set_yticks(range(1, 5))
        ax2.set_ylabel('Source z-values', fontsize=12)
    ax.set_ylabel('Causality: Interictals\n minus Baseline', fontsize=12)
    ax.set_ylim([0, 0.7])
    # Title the requested axes explicitly: after twinx() pyplot's "current"
    # axes is the twin, and a caller-supplied ax may not be current at all.
    ax.set_title('{} interictals cluster'.format(
        'Right' if condition == 'R' else 'Left'))
    ax.legend(lines, labels, loc=0)
    if show_when_done:
        plt.show()