示例#1
0
def concat_RL(R_img, L_img, rl_idx_pair, rl_sign_pair=None):
    """
    Given R and L ICA images and their component index pairs, concatenate images to
    create bilateral image using the index pairs. Sign flipping can be specified in rl_sign_pair.

    Parameters
    ----------
    R_img, L_img : 4D images of identical shape whose 4th axis indexes components.
        R_img must expose a ``terms`` mapping (possibly empty). When it is non-empty,
        L_img must expose one too, and the term values are combined as well.
    rl_idx_pair : pair of equal-length sequences ``(r_indices, l_indices)``.
        Output volume i combines R component ``r_indices[i]`` with L component
        ``l_indices[i]``.
    rl_sign_pair : optional pair of sequences ``(r_signs, l_signs)`` of the same
        length as ``rl_idx_pair``. Each sign multiplies its component image (and its
        term values) before combining. When omitted, every sign is 1.

    Returns
    -------
    concat_img : image concatenating one ``r + l`` volume per index pair. If R_img has
        terms, ``concat_img.terms`` maps each term name to an array holding, for each
        index pair, the mean of the sign-adjusted R and L term values.
    """
    # Make sure images have same number of components and indices are less than the n_components
    assert R_img.shape == L_img.shape
    n_components = R_img.shape[3]
    assert np.max(rl_idx_pair) < n_components
    n_rl_imgs = len(rl_idx_pair[0])
    assert n_rl_imgs == len(rl_idx_pair[1])
    if rl_sign_pair:
        assert n_rl_imgs == len(rl_sign_pair[0])
        assert n_rl_imgs == len(rl_sign_pair[1])

    # Fix the term order once; it is the key order of the output `terms` dict.
    terms = list(R_img.terms.keys())
    rl_imgs = []
    rl_term_vals = []

    for i, (rci, lci) in enumerate(zip(rl_idx_pair[0], rl_idx_pair[1])):
        # sign flipping (defaults to +1 when no sign pairs are given)
        r_sign = rl_sign_pair[0][i] if rl_sign_pair else 1
        l_sign = rl_sign_pair[1][i] if rl_sign_pair else 1

        R_comp_img = math_img("%d*img" % (r_sign), img=index_img(R_img, rci))
        L_comp_img = math_img("%d*img" % (l_sign), img=index_img(L_img, lci))

        # combine images
        rl_imgs.append(math_img("r+l", r=R_comp_img, l=L_comp_img))

        # combine terms
        if terms:
            r_names, r_vals = get_ic_terms(R_img.terms, rci, sign=r_sign)
            l_names, l_vals = get_ic_terms(L_img.terms, lci, sign=l_sign)
            # get_ic_terms returns term names alongside the values, so their order
            # is not guaranteed to be the dict order. Average by *name*, in `terms`
            # order, so each value lines up with its key in the zip below.
            r_lookup = dict(zip(r_names, r_vals))
            l_lookup = dict(zip(l_names, l_vals))
            rl_term_vals.append([(r_lookup[t] + l_lookup[t]) / 2.0 for t in terms])

    # Squash into single image
    concat_img = nib.concat_images(rl_imgs)
    if terms:
        # rows: index pairs, columns: terms -> transpose to one array per term
        concat_img.terms = dict(zip(terms, np.asarray(rl_term_vals).T))
    return concat_img
示例#2
0
def plot_term_comparisons(terms, labels, ic_idx_list, sign_list, color_list=('g', 'r', 'b'),
                          top_n=4, bottom_n=4, standardize=True, out_dir=None):
    """
    Take the list of ica image terms and the indices of components to be compared, and
    plots the top n and bottom n term values for each component as a radar graph.

    The sign_list should indicate whether term values should be flipped (-1) or not (1).

    Parameters
    ----------
    terms : list of term mappings, one per image being compared.
    labels : list of names, one per entry of ``terms``. Used for the legend, the
        summary column names and the output file names.
    ic_idx_list : list (one per entry of ``terms``) of equal-length sequences of
        component indices. The n-th index of every sequence is drawn together on
        the n-th radar chart.
    sign_list : same layout as ``ic_idx_list``; the sign applied to each
        component's term values.
    color_list : plot colors; needs at least one color per entry of ``terms``.
    top_n, bottom_n : number of highest / lowest terms taken from each component.
    standardize : forwarded to get_ic_terms (z-scores the plotted values when True).
    out_dir : directory that receives one radar PNG per comparison and
        ``term_comparison_summary.csv``. When None, nothing is written to disk.

    Returns
    -------
    all_term_df : DataFrame with one row per comparison. For each label it holds a
        component-index column followed by that component's top/bottom term names.
    """
    assert len(terms) == len(labels)
    assert len(terms) == len(ic_idx_list)
    assert len(terms) == len(sign_list)
    # zip() below only consumes one color per label; extra colors are harmless.
    assert len(color_list) >= len(terms)
    n_comp = len(ic_idx_list[0])   # length of each ic_idx_list and sign_list
    for i in range(len(terms)):
        assert len(ic_idx_list[i]) == n_comp
        assert len(sign_list[i]) == n_comp

    # iterate over the ic_idx_list and sign_list for each term and plot
    # store top n and bottom n terms for each label.
    # object dtype: a fixed-width bytes dtype would truncate long term names and
    # write them as b'...' literals in the CSV under Python 3.
    term_arr = np.empty((len(labels), n_comp, top_n + bottom_n), dtype=object)
    for n in range(n_comp):

        terms_of_interest = []
        term_vals = []
        name = ''

        for i, (term, label) in enumerate(zip(terms, labels)):
            idx = ic_idx_list[i][n]
            sign = sign_list[i][n]
            # Get list of top n and bottom n terms for each term list
            top_terms = get_n_terms(
                term, idx, n_terms=top_n, top_bottom='top', sign=sign)
            bottom_terms = get_n_terms(
                term, idx, n_terms=bottom_n, top_bottom='bottom', sign=sign)
            combined = np.append(top_terms, bottom_terms)
            terms_of_interest.append(combined)
            term_arr[i][n] = combined

            # Also store term vals (z-score if standardize) for each list
            t, vals = get_ic_terms(term, idx, sign=sign, standardize=standardize)
            s = pd.Series(vals, index=t, name=label)
            term_vals.append(s)

            # Construct name for the comparison (trailing space is relied on below)
            name += label + '[%d] ' % (idx)

        # Data for all the terms
        termscore_df = pd.concat(term_vals, axis=1)

        # Get unique terms from terms_of_interest list
        toi_unique = np.unique(terms_of_interest)

        # Get values for unique terms_of_interest
        data = termscore_df.loc[toi_unique]
        data = data.sort_values(list(labels), ascending=False)

        # Now plot radar!
        N = len(toi_unique)
        theta = radar_factory(N)
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(1, 1, 1, projection='radar')
        title = "Term comparisons for %scomponents" % (name)
        ax.set_title(title, weight='bold', size='medium', position=(0.5, 1.1),
                     horizontalalignment='center', verticalalignment='center')

        y_min, y_max, y_tick = nice_bounds(data.values.min(), data.values.max())
        ax.set_ylim(y_min, y_max)
        ax.set_yticks([0], minor=True)
        ax.set_yticks(y_tick)
        ax.yaxis.grid(which='major', linestyle=':')
        ax.yaxis.grid(which='minor', linestyle='-')

        for label, color in zip(labels, color_list):
            ax.plot(theta, data[label], color=color)
            ax.fill(theta, data[label], facecolor=color, alpha=0.25)
        ax.set_varlabels(data.index.values)

        legend = plt.legend(labels, loc=(1.1, 0.9), labelspacing=0.1)
        plt.setp(legend.get_texts(), fontsize='small')

        # Saving
        if out_dir is not None:
            save_and_close(
                out_path=op.join(out_dir, '%sterm_comparisons.png' % (
                    name.replace(" ", "_"))))

    # Save top n and bottom n terms for each label
    term_dfs = []
    col_names = ["top%d" % (n + 1) for n in range(top_n)] + ["bottom%d" % (n + 1) for n in range(bottom_n)]
    for i, label in enumerate(labels):
        term_df = pd.DataFrame(term_arr[i], columns=["%s_%s" % (label, col) for col in col_names])
        term_df.insert(0, "%s_idx" % (label), ic_idx_list[i])
        term_dfs.append(term_df)
    all_term_df = pd.concat(term_dfs, axis=1)
    # Previously op.join(None, ...) raised TypeError with the default out_dir=None.
    if out_dir is not None:
        out_file = op.join(out_dir, 'term_comparison_summary.csv')
        all_term_df.to_csv(out_file)
    return all_term_df