def mpl_palette(n_colors, variation='Set2'):  # or variation='colorblind'
    """Return a seaborn palette as both a matplotlib colormap and a colour list.

    :param n_colors: number of colours to sample from the palette.
    :param variation: any seaborn palette name; colours are desaturated to 0.8.
    :returns: ``(colormap, colour_list)`` built by ``blend_palette`` from the
              same desaturated base colours.
    """
    import seaborn as sb

    base_colors = sb.color_palette(variation, n_colors, desat=0.8)
    colormap = sb.blend_palette(base_colors, n_colors=n_colors, as_cmap=True)
    color_list = sb.blend_palette(base_colors, n_colors=n_colors)
    return colormap, color_list
Example #2
0
def make_plot(X_train, y_train, X, y, test_data, model, model_name, features, response):
    """Fit ``model`` and save diagnostic, correlation and coefficient figures.

    Three figures are written with ``plt.savefig``, all prefixed with
    ``model_name``:

    * a 2x2 panel (regression, box plot, squared residuals, interaction) for
      the fifth column of ``X``;
    * a multi-variable correlation plot of ``test_data``;
    * a coefficient plot for the formula ``diagnosis ~ <features>``.

    :param X_train, y_train: data passed to ``model.fit``.
    :param X, y: data for the diagnostic panels; ``X`` needs at least five
                 columns (columns 3 and 4 are plotted).
    :param test_data: frame used by the regression, correlation and
                      coefficient plots.
    :param model: estimator exposing ``fit`` and ``predict``.
    :param model_name: ``'linear'`` or ``'logistic'`` selects the interaction
                       panel; any other value leaves that panel empty.
    :param features: column names joined into the coefficient-plot formula.
    :param response: unused.
    """
    feature = X.columns
    f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharey=False)
    sns.regplot(X[feature[4]], y, test_data, ax=ax1)
    sns.boxplot(X[feature[4]], y, color="Blues_r", ax=ax2)
    model.fit(X_train, y_train)
    sns.residplot(X[feature[4]], (model.predict(X) - y) ** 2, color="indianred", lowess=True, ax=ax3)
    # `==`, not `is`: string identity only matches by accident of interning,
    # so a name built at runtime would silently skip both branches.
    if model_name == 'linear':
        sns.interactplot(X[feature[3]], X[feature[4]], y, ax=ax4, filled=True, scatter_kws={"color": "dimgray"}, contour_kws={"alpha": .5})
    elif model_name == 'logistic':
        pal = sns.blend_palette(["#4169E1", "#DFAAEF", "#E16941"], as_cmap=True)
        levels = np.linspace(0, 1, 11)
        # Draw explicitly into the fourth panel instead of relying on the
        # implicit "current axes".
        sns.interactplot(X[feature[3]], X[feature[4]], y, levels=levels, cmap=pal, logistic=True, ax=ax4)
    ax1.set_title('Regression')
    ax2.set_title(feature[4]+' Value')
    ax3.set_title(feature[4]+' Residuals')
    ax4.set_title('Two-value Interaction')
    f.tight_layout()
    plt.savefig(model_name+'_'+feature[4], bbox_inches='tight')

    # Multi-variable correlation significance level
    f, ax = plt.subplots(figsize=(10, 10))
    cmap = sns.blend_palette(["#00008B", "#6A5ACD", "#F0F8FF",
                              "#FFE6F8", "#C71585", "#8B0000"], as_cmap=True)
    sns.corrplot(test_data, annot=False, diag_names=False, cmap=cmap)
    ax.grid(False)
    ax.set_title('Multi-variable correlation significance level')
    plt.savefig(model_name+'_multi-variable_correlation', bbox_inches='tight')

    # complete coefficient plot - believe this is only for linear regression
    sns.coefplot("diagnosis ~ "+' + '.join(features), test_data, intercept=True)
    plt.xticks(rotation='vertical')
    plt.savefig(model_name+'_coefficient_effects', bbox_inches='tight')
Example #3
0
def selectColormap(data_range, percent=0.1):
    """
    Determine whether to use a sequential or diverging color map, based on the
    defined data range to be used for the plot. Note the diverging color map
    only works when the data range spans zero (and it won't automatically put
    the neutral colour at zero).

    The range is treated as diverging when
    ``|(|max| - |min|) / (max - min)| < percent``, i.e. when the data extend
    roughly equally either side of zero. A range with no spread
    (``max == min``) gets the sequential map.

    :param data_range: array-like containing either the minimum and maximum
                       levels of the data range, or the array of levels (e.g.
                       for contour maps).
    :param float percent: Threshold for switching from diverging to sequential
                          colormap

    :returns: A `matplotlib.colors.LinearSegmentedColormap` instance used for
              setting the `Figure.cmap` attribute.

    :raises: TypeError for non-numeric input, or if the `data_range` is not
             array-like. ValueError if `data_range` is empty.
    """
    import numbers

    if not isinstance(data_range, (list, np.ndarray, tuple)):
        raise TypeError("Data range must be a list or array of numeric values")
    if len(data_range) == 0:
        raise ValueError("Data range must contain at least one value")

    upper = max(data_range)
    # numbers.Real also accepts NumPy scalars (np.int64, np.float32, ...),
    # which the former (int, float) check wrongly rejected.
    if not isinstance(upper, numbers.Real):
        raise TypeError("Data range must be a list or array of numeric values")
    lower = min(data_range)

    span = upper - lower
    if span == 0:
        # No spread: the ratio below is undefined, so use the sequential map.
        diverging = False
    else:
        diverging = abs((abs(upper) - abs(lower)) / span) < percent

    if diverging:
        palette = sns.color_palette("RdBu", 7)
    else:
        palette = sns.color_palette("YlOrRd", 7)
    return sns.blend_palette(palette, as_cmap=True)
Example #4
0
def mpl_palette(n_colors, variation='Set2'):  # or variation='colorblind'
    """Expose a seaborn palette as a (colormap, colour list) pair.

    Both outputs are blended from the palette ``variation`` sampled at
    ``n_colors`` colours and desaturated to 0.8. The colormap comes first.
    """
    import seaborn as sb

    desaturated = sb.color_palette(variation, n_colors, desat=0.8)

    def _blend(**extra):
        # Shared blend so both outputs use identical base colours and length.
        return sb.blend_palette(desaturated, n_colors=n_colors, **extra)

    return _blend(as_cmap=True), _blend()
Example #5
0
def plot_confusion_matrix(cm,
                          ax=None,
                          figsize=(4, 4),
                          fs=12,
                          title=None,
                          norm_axis=1,
                          normalize=True):
    """Draw a confusion matrix as two overlaid heatmaps.

    The diagonal is coloured with a blue ramp and the off-diagonal cells with
    a red ramp. Values are multiplied by 100 before plotting.

    :param cm: square confusion matrix (ndarray or DataFrame) of any size.
    :param ax: axes to draw into; a new 200-dpi figure is created when None.
    :param figsize: size of the figure created when ``ax`` is None.
    :param fs: title font size.
    :param title: optional axes title.
    :param norm_axis: axis whose sums each entry is divided by (1 = rows,
                      0 = columns).
    :param normalize: normalise ``cm`` and print the mean of its diagonal.
    """
    if normalize:
        # Re-insert the summed axis so the division broadcasts along the right
        # dimension for both row (1) and column (0) normalisation; the old
        # `[:, np.newaxis]` was only correct for rows. Rows/columns summing to
        # zero produce NaN.
        totals = np.expand_dims(
            np.asarray(cm.sum(axis=norm_axis), dtype=float), norm_axis)
        cm = cm.astype('float') / totals

        print("Acc = %.4f" % np.mean(np.diag(cm)))

    if ax is None:
        _, ax = plt.subplots(1,
                             1,
                             constrained_layout=True,
                             figsize=figsize,
                             dpi=200)

    # Size the masks from the matrix instead of assuming 10 classes.
    n_classes = np.shape(cm)[0]
    mask1 = np.eye(n_classes) == 0  # hides off-diagonal -> pal1 draws diagonal
    mask2 = np.eye(n_classes) == 1  # hides diagonal -> pal2 draws the errors

    pal1 = sns.blend_palette(
        ["#f7f7f7", "#d1e5f0", "#92c5de", "#4393c3", "#2166ac", "#053061"],
        as_cmap=True)

    pal2 = sns.blend_palette(
        ["#f7f7f7", "#fddbc7", "#f4a582", "#d6604d", "#b2182b", "#67001f"],
        as_cmap=True)

    sns.heatmap(100 * cm,
                fmt=".1f",
                annot=False,
                cmap=pal1,
                linewidths=1,
                cbar=True,
                mask=mask1,
                ax=ax,
                linecolor="#ffffff")

    sns.heatmap(100 * cm,
                fmt=".1f",
                annot=False,
                cmap=pal2,
                linewidths=1,
                cbar=True,
                mask=mask2,
                ax=ax,
                linecolor="#ffffff")

    ax.set_ylabel('True label')

    ax.set_xlabel('Predicted label')

    if title is not None:
        ax.set_title(title, fontsize=fs)
Example #6
0
 def setUp(self):
     """Prepare a diverging (RdBu) and a sequential (white-green-red) colormap."""
     import seaborn as sns
     sequential_colors = [
         (1, 1, 1),
         (0.000, 0.627, 0.235),
         (0.412, 0.627, 0.235),
         (0.663, 0.780, 0.282),
         (0.957, 0.812, 0.000),
         (0.925, 0.643, 0.016),
         (0.835, 0.314, 0.118),
         (0.780, 0.086, 0.118),
     ]
     self.diverging_cmap = sns.blend_palette(sns.color_palette("RdBu", 7),
                                             as_cmap=True)
     self.sequential_cmap = sns.blend_palette(sequential_colors, as_cmap=True)
Example #7
0
def selectColormap(data_range, percent=0.1):
    """
    Determine whether to use a sequential or diverging color map, based on the
    defined data range to be used for the plot. Note the diverging color map
    only works when the data range spans zero (and it won't automatically put
    the neutral colour at zero).

    The sequential map is a white-to-green-to-red palette using
    recommendations from ISO22324 (2015).
    Unsaturated colour palette (not used here):
    [(0.486, 0.722, 0.573), (0.447, 0.843, 0.714), (0.875, 0.882, 0.443),
    (0.969, 0.906, 0.514), (0.933, 0.729, 0.416), (0.906, 0.522, 0.373),
    (0.937, 0.522, 0.616)]

    The range is treated as diverging when
    ``|(|max| - |min|) / (max - min)| < percent``. A range with no spread
    (``max == min``) gets the sequential map.

    :param data_range: array-like containing either the minimum and maximum
                       levels of the data range, or the array of levels (e.g.
                       for contour maps).
    :param float percent: Threshold for switching from diverging to sequential
                          colormap

    :returns: A `matplotlib.colors.LinearSegmentedColormap` instance used for
              setting the `Figure.cmap` attribute.

    :raises: TypeError for non-numeric input, or if the `data_range` is not
             array-like. ValueError if `data_range` is empty.
    """
    import numbers

    if not isinstance(data_range, (list, np.ndarray, tuple)):
        raise TypeError("Data range must be a list or array of numeric values")
    if len(data_range) == 0:
        raise ValueError("Data range must contain at least one value")

    upper = max(data_range)
    # numbers.Real also accepts NumPy scalars (np.int64, np.float32, ...),
    # which the former (int, float) check wrongly rejected.
    if not isinstance(upper, numbers.Real):
        raise TypeError("Data range must be a list or array of numeric values")
    lower = min(data_range)

    span = upper - lower
    if span == 0:
        # No spread: the ratio below is undefined, so use the sequential map.
        diverging = False
    else:
        diverging = abs((abs(upper) - abs(lower)) / span) < percent

    if diverging:
        palette = sns.color_palette("RdBu", 7)
    else:
        palette = [(1, 1, 1),
                   (0.000, 0.627, 0.235),
                   (0.412, 0.627, 0.235),
                   (0.663, 0.780, 0.282),
                   (0.957, 0.812, 0.000),
                   (0.925, 0.643, 0.016),
                   (0.835, 0.314, 0.118),
                   (0.780, 0.086, 0.118)]
    return sns.blend_palette(palette, as_cmap=True)
Example #8
0
def rain_colormap(subset=slice(None, None)):
    """Return a 21-step rainfall colormap blended from fixed RGB anchor colours.

    :param subset: slice selecting which anchor colours take part in the
                   blend; all 15 by default.
    """
    import seaborn as sns

    anchors = [
        [0.988235, 0.988235, 0.992157],
        [0.811765, 0.831373, 0.886275],
        [0.627451, 0.678431, 0.788235],
        [0.521569, 0.615686, 0.729412],
        [0.584314, 0.698039, 0.749020],
        [0.690196, 0.803922, 0.772549],
        [0.847059, 0.905882, 0.796078],
        [1.000000, 0.980392, 0.756863],
        [0.996078, 0.839216, 0.447059],
        [0.996078, 0.670588, 0.286275],
        [0.992157, 0.501961, 0.219608],
        [0.968627, 0.270588, 0.152941],
        [0.835294, 0.070588, 0.125490],
        [0.674510, 0.000000, 0.149020],
        [0.509804, 0.000000, 0.149020],
    ]
    return sns.blend_palette(anchors[subset], n_colors=21, as_cmap=True)
Example #9
0
    def plotPressureMeanDiff(self):
        """
        Map the historical-minus-synthetic mean central pressure difference.

        ``self.histMean - self.synMean`` is drawn on a Mercator map bounded by
        ``self.lon_range`` / ``self.lat_range``, with a fixed -25..25 colour
        range and a blended coolwarm colormap, and saved as
        ``meanPressureDiff.png`` under ``self.plotPath``.

        """
        lons, lats = self.lon_range, self.lat_range
        figure = ArrayMapFigure()

        map_kwargs = {
            'llcrnrlon': lons.min(),
            'llcrnrlat': lats.min(),
            'urcrnrlon': lons.max(),
            'urcrnrlat': lats.max(),
            'projection': 'merc',
            'resolution': 'i',
        }

        difference = self.histMean - self.synMean
        xgrid, ygrid = np.meshgrid(lons, lats)
        figure.add(np.transpose(difference), xgrid, ygrid,
                   "Historical - Synthetic", (-25, 25),
                   "Mean central pressure difference (hPa)", map_kwargs)
        figure.cmap = sns.blend_palette(sns.color_palette("coolwarm", 9),
                                        as_cmap=True)
        figure.plot()
        saveFigure(figure, pjoin(self.plotPath, 'meanPressureDiff.png'))
Example #10
0
    def plot_means(self, save=False, ax=None, label_x=True):
        """Bar-plot the mean ``csum`` of ``self.bold_mag`` for each ``cond`` level.

        Before plotting, go trials with ``cond <= 50`` are pooled into
        ``cond == 40`` and no-go trials with ``cond >= 50`` into
        ``cond == 60``. ``self.make_bold_dfs()`` is called first if
        ``self.bold_mag`` does not exist yet.

        :param save: also write ``<model>_<decay>_means.png`` (300 dpi) and
                     ``.svg``, where ``<model>`` is
                     ``describe_model(self.depends_on)``.
        :param ax: axes to draw into; a new 5x5 figure is created when None.
        :param label_x: label the x axis ('pGo') and its ticks.
        :returns: the axes drawn into.
        """
        def redgreen(nc):
            # red -> green blend with ``nc`` steps
            return sns.blend_palette(["#c0392b", "#27ae60"], n_colors=nc)

        sns.set(style='white', font_scale=1.5)
        if not hasattr(self, 'bold_mag'):
            self.make_bold_dfs()
        if ax is None:
            _, ax = plt.subplots(1, figsize=(5, 5))

        titl = describe_model(self.depends_on)
        df = self.bold_mag.copy()
        # `.ix` was removed in pandas 1.0; with a boolean row mask and a column
        # label, `.loc` performs the same assignment.
        df.loc[(df.choice == 'go') & (df.cond <= 50), 'cond'] = 40
        df.loc[(df.choice == 'nogo') & (df.cond >= 50), 'cond'] = 60
        cond_levels = np.sort(df.cond.unique())
        # Keyword x/y: seaborn >= 0.12 no longer accepts them positionally.
        sns.barplot(x='cond', y='csum', data=df, order=cond_levels,
                    palette=redgreen(6), ax=ax)

        # Select 'csum' before aggregating so non-numeric columns are never
        # averaged (pandas >= 2.0 raises on them).
        mu = df.groupby(['choice', 'cond'])['csum'].mean()
        ax.set_ylim(mu.min() * .55, mu.max() * 1.15)
        ax.set_xticklabels([])
        ax.set_xlabel('')
        if label_x:
            ax.set_xlabel('pGo', fontsize=24)
            ax.set_xticklabels(cond_levels, fontsize=22)
        # Raw string: '\S' is an invalid escape sequence in a plain literal.
        ax.set_ylabel(r'$\Sigma \theta_{G}$', fontsize=30)
        ax.set_yticklabels([])
        sns.despine()
        plt.tight_layout()
        if save:
            plt.savefig('_'.join([titl, self.decay, 'means.png']), dpi=300)
            plt.savefig(
                '_'.join([titl, self.decay, 'means.svg']), format='svg', rasterized=True)
        return ax
Example #11
0
def plot_rwp_bins(name, log_dir, ax=None, r_scale=True):
    """
    Plot clustering measurements and predictions in sSFR bins.

    Data come from ``util.get_wprp_bin_data(name, log_dir)``. Each bin gets
    a colour from a red-to-blue blend: the measured curves are drawn with
    their errors and the predictions without. With ``r_scale`` the curves
    are drawn by ``plot_rwp``, otherwise by ``plot_wprp``. The x axis is
    logarithmic and limited to [0.09, 30].
    """
    r, actual, pred, errs = util.get_wprp_bin_data(name, log_dir)
    if ax is None:
        plt.figure(figsize=(12, 12))
        ax = plt.gca()

    plotter = plot_rwp if r_scale else plot_wprp
    colors = sns.blend_palette([red_col, blue_col], len(actual))
    for col, xi, var in zip(colors, actual, errs):
        plotter(r, xi, var, ax, col)
    for col, xi in zip(colors, pred):
        plotter(r, xi, [], ax, col)

    ylabel = ('$r_p$ $w_p(r_p)$ $[Mpc$ $h^{-1}]$' if r_scale
              else '$w_p(r_p)$ $[Mpc$ $h^{-1}]$')
    ax.set_ylabel(ylabel, fontsize=27)
    ax.set_xlabel('$r_p$ $[Mpc$ $h^{-1}]$', fontsize=27)
    ax.set_xscale('log')
    ax.set_xlim(9e-2, 30)
    return style_plots(ax)
Example #12
0
    def __init__(self, X, cells=None, P=None):
        """Build and show an interactive orthogonal-slice viewer for ``X``.

        :param X: volume to browse; indexed along three axes ('row', 'col',
                  'depth').
        :param cells: optional cell collection; when given, ``cell_idx``
                      starts at 0.
        :param P: optional array no larger than ``X`` along any axis; it is
                  embedded, centred, into a zero array shaped like ``X``.
                  The overlay is drawn with a yellow-to-deeppink colormap.

        Registers scroll, button-press and key-press handlers, draws via
        ``self.replot()`` and calls ``plt.show()``.
        """
        import seaborn as sns
        self.plot_paramsP = dict(cmap=sns.blend_palette(['yellow', 'deeppink'],
                                                        as_cmap=True),
                                 zorder=5)

        self.X = X
        self.P = P
        self.cells = cells
        self.cell_idx = 0 if cells is not None else None
        self.cut = OrderedDict(zip(['row', 'col', 'depth'], [0, 0, 0]))

        if P is not None:
            self.P = 0 * self.X
            offsets = [(xs - ps + 1) // 2 for xs, ps in zip(self.X.shape, P.shape)]
            # Explicit stop indices: the old `i:-i` slice was empty when the
            # offset was 0 and one element short for odd size differences,
            # so the assignment failed with a shape mismatch.
            window = tuple(slice(o, o + n) for o, n in zip(offsets, P.shape))
            self.P[window] = P

        fig = plt.figure(facecolor='w', figsize=(10, 10))
        gs = plt.GridSpec(4, 4)
        ax = dict()
        ax['depth'] = fig.add_subplot(gs[1:, :-1])
        ax['row'] = fig.add_subplot(gs[0, :3], sharex=ax['depth'])
        ax['col'] = fig.add_subplot(gs[1:, -1], sharey=ax['depth'])
        ax['3d'] = fig.add_subplot(gs[0, -1], projection='3d')

        self.fig, self.ax = fig, ax
        self.fig.canvas.mpl_connect('scroll_event', self.on_scroll)
        self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.fig.canvas.mpl_connect('key_press_event', self.on_key)
        self.replot()
        plt.show()
Example #13
0
    def __init__(self):
        """Start with no subfigures and a 7-step blended YlOrRd colormap."""
        Figure.__init__(self)
        self.subfigures = []
        self.cmap = sns.blend_palette(sns.color_palette("YlOrRd", 7),
                                      as_cmap=True)
        # NOTE(review): this stores the FigureCanvas class, not an instance.
        self.canvas = FigureCanvas
Example #14
0
    def plot_NMF_ROIs(self, outdir='./'):
        """Save one diagnostic PNG per NMF-segmented cell.

        For every key of ``self.project()`` restricted to
        ``short_name='stm'`` and ``method_name='nmf'``, each cell gets a
        figure with its calcium trace, its spike trace and its mask contour
        over the mean scan template. Files are written to
        ``<outdir>/scan_idx<scan_idx>/slice<slice>/cell<NNN>_animal_id_<id>_session_<session>.png``.

        :param outdir: root output directory; a leading ``~`` is expanded.
        """
        sns.set_context('paper')
        theCM = sns.blend_palette(['lime', 'gold', 'deeppink'],
                                  n_colors=10)  # plt.cm.RdBu_r
        # Expand "~" once so the directory created below and the files saved
        # into it refer to the same place (previously only mkdir expanded it).
        outdir = os.path.expanduser(outdir)

        for key in (self.project() * SegmentMethod() * SpikeInference()
                    & dict(short_name='stm', method_name='nmf')).fetch.as_dict:
            mask_px, mask_w, ca, sp = (
                SegmentMask() * Trace() * Spikes()
                & key).fetch.order_by('mask_id')['mask_pixels', 'mask_weights',
                                                 'ca_trace', 'spike_trace']

            template = np.stack(
                [
                    normalize(
                        bugfix_reshape(t)[..., key['slice'] - 1].squeeze())
                    for t in (ScanCheck() & key).fetch['template']
                ],
                axis=2).mean(
                    axis=2
                )  # TODO: remove bugfix_reshape once djbug #191 is fixed

            d1, d2 = tuple(
                map(int, (ScanInfo() & key).fetch1['px_height', 'px_width']))
            masks = Segment.reshape_masks(mask_px, mask_w, d1, d2)
            gs = plt.GridSpec(6, 1)
            slice_dir = outdir + '/scan_idx{scan_idx}/slice{slice}'.format(**key)
            # Replaces `sh.mkdir -p` wrapped in a bare `except: pass`, which
            # also hid real failures (e.g. permissions) until savefig broke.
            os.makedirs(slice_dir, exist_ok=True)
            for cell, (ca_trace, sp_trace) in enumerate(zip(ca, sp)):
                with sns.axes_style('white'):
                    fig = plt.figure(figsize=(6, 8))
                    ax_image = fig.add_subplot(gs[2:, :])

                with sns.axes_style('ticks'):
                    ax_ca = fig.add_subplot(gs[0, :])
                    ax_sp = fig.add_subplot(gs[1, :], sharex=ax_ca)
                ax_ca.plot(ca_trace, 'green', lw=1)
                ax_sp.plot(sp_trace, 'k', lw=1)

                ax_image.imshow(template, cmap=plt.cm.gray)
                ax_image.contour(masks[..., cell], colors=theCM, zorder=10)
                sns.despine(ax=ax_ca)
                sns.despine(ax=ax_sp)
                ax_ca.axis('tight')
                ax_sp.axis('tight')
                fig.suptitle(
                    "animal_id {animal_id}:session {session}:scan_idx {scan_idx}:{method_name}:slice{slice}:cell{cell}"
                    .format(cell=cell + 1, **key))
                fig.tight_layout()

                plt.savefig(
                    slice_dir +
                    "/cell{cell:03d}_animal_id_{animal_id}_session_{session}.png"
                    .format(cell=cell + 1, **key))
                plt.close(fig)
Example #15
0
 def plot_tile_qual(self):
     """Save a heatmap of per-tile quality by read position as tile_qual_plot.svg.

     Rows are the keys of ``self.results_dict['flowcell_tile_qual_dict']``
     (labelled 'Tile ID'). Columns are base positions ``1..seq_length``.
     Colours span 0-42.
     """
     upper = self.seq_length + 1
     qualities = self.results_dict['flowcell_tile_qual_dict']
     sns.set(style="ticks", context="talk")
     plt.figure(figsize=(12, 8))
     heat_cmap = sns.blend_palette(
         ('#ee0000', '#ecee00', '#00b61f', '#0004ff', '#0004ff'),
         n_colors=6,
         as_cmap=True,
         input='rgb')
     base_positions = range(1, upper)
     odd_positions = range(1, upper, 2)
     ax = sns.heatmap(pd.DataFrame(qualities).T,
                      cmap=heat_cmap,
                      vmin=0,
                      vmax=42,
                      xticklabels=base_positions)
     # NOTE(review): heatmap cells are centred at x = 0.5, 1.5, ...; the
     # integer tick positions below fall on cell edges. Confirm alignment.
     ax.set_xticks(odd_positions)
     ax.set_xticks(base_positions, minor=True)
     ax.set_xticklabels(odd_positions)
     plt.yticks(rotation=0)
     for set_text, text in ((plt.xlabel, 'Base position in read'),
                            (plt.ylabel, 'Tile ID'),
                            (plt.title, 'Average sequence quality per tile')):
         set_text(text)
     plt.tight_layout()
     plt.savefig(os.path.join(self.outpath, 'tile_qual_plot.svg'))
Example #16
0
 def _init_palette(self):
     """Store ``self.number_colors`` CSS ``rgb(r, g, b)`` strings in ``self.palette``.

     The colours run seagreen -> ghostwhite -> #4168B7. Each 0-1 channel
     is scaled to 0-255 and rounded.
     """
     # Current seaborn returns a list of RGB tuples (so `basis * 255` would
     # repeat the list and the 4-way unpack would fail); older releases
     # returned an RGBA ndarray. Normalise both to an (n, 3) float array.
     basis = np.asarray(
         sns.blend_palette(["seagreen", "ghostwhite", "#4168B7"],
                           self.number_colors),
         dtype=float)[:, :3]
     self.palette = [
         "rgb(%d, %d, %d)" % (r, g, b)
         for r, g, b in np.round(basis * 255)
     ]
Example #17
0
def visualize_correlations(data, annotate=False, fig_size=16):
    """
    Generates a correlation matrix heat map.

    Only the lower triangle is drawn; the upper triangle and the diagonal
    are masked because they mirror it.

    Parameters
    ----------
    data : array-like
        Pandas data frame containing the entire data set.

    annotate : boolean, optional, default False
        Annotate each cell with its correlation, rounded to 2 decimals.

    fig_size : int, optional, default 16
        Figure width in inches; the height is three quarters of it.
    """
    corr = data.corr()

    if annotate:
        corr = np.round(corr, 2)

    # Mask the upper triangle. Use the builtin `bool`: the `np.bool` alias
    # was deprecated in NumPy 1.20 and removed in 1.24.
    mask = np.zeros_like(corr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    fig, ax = plt.subplots(figsize=(fig_size, fig_size * 3 / 4))
    colormap = sb.blend_palette(sb.color_palette('coolwarm'), as_cmap=True)
    # Draw into the figure created above rather than the implicit current axes.
    sb.heatmap(corr, mask=mask, cmap=colormap, annot=annotate, ax=ax)
    fig.tight_layout()
Example #18
0
def plot_correlations(df, stocks, factors):
    """Annotated heatmap of the sentiment/price correlation per (stock, factor).

    For every pair, the rows of ``df`` whose ``Stock`` and ``Factor`` match
    are passed to ``normalize_price``, and the correlation between
    ``Sentiment`` and ``NormPrice`` is stored. Pairs without rows stay at 0.

    :param df: frame with at least ``Stock``, ``Factor`` and ``Sentiment``
               columns.
    :param stocks: heatmap row labels, in order.
    :param factors: heatmap column labels, in order.
    """
    correlations = np.zeros((len(stocks), len(factors)))

    for s_idx, stock in enumerate(stocks):
        for f_idx, factor in enumerate(factors):
            # The return value of normalize_price is ignored, so it presumably
            # adds 'NormPrice' in place. Work on an explicit copy: mutating a
            # boolean-filtered slice raises SettingWithCopyWarning.
            stock_df = df[(df['Stock'] == stock) & (df['Factor'] == factor)].copy()
            if len(stock_df) > 0:
                normalize_price(stock_df)
                correlations[s_idx, f_idx] = stock_df['Sentiment'].corr(
                    stock_df['NormPrice'])

    cmap = sns.blend_palette(['#6ce5e8', '#b1b3b3', '#EE959E'],
                             n_colors=6,
                             as_cmap=True,
                             input='rgb')
    sns.heatmap(correlations,
                xticklabels=factors,
                yticklabels=stocks,
                linewidths=.5,
                cmap=cmap,
                annot=True)
    plt.yticks(rotation=0)
    plt.title("Sentiment correlation with stock price")
Example #19
0
def draw_seaborn_scatter(data, prediction):
    """Scatter the first two columns of ``data``, coloured by ``prediction[:, 0]``.

    The hue runs through a red-to-blue blended colormap on equal-aspect axes
    in a 6x6 figure, which is then shown.
    """
    sns.set(style="darkgrid")

    red, blue = '#ff0000', '#0000ff'
    hue_cmap = sns.blend_palette([red, red, blue, blue], as_cmap=True)
    _, axes = pyplot.subplots(figsize=(6, 6))
    axes.set_aspect("equal")
    sns.scatterplot(x=data[:, 0], y=data[:, 1], hue=prediction[:, 0],
                    palette=hue_cmap)
    pyplot.show()
def main():
    """Map one time slice of the splined MODIS climatology and save it.

    Reads ``modis_climatology_splined.bin`` with ``np.fromfile`` (default
    float64) and reshapes it to (366, 681, 841). Slice index 20 is drawn on
    a cylindrical Basemap with the STE11aAust state-boundary shapefile,
    using a 7-colour white-to-green map over 0-6. The figure is saved to
    ``/Users/mdekauwe/Desktop/LAI_NSW.png`` and then shown.
    """
    ncols = 841
    nrows = 681
    ndays = 366

    # Raw binary data: open in "rb" (not text mode) and let `with` close it.
    with open("modis_climatology_splined.bin", "rb") as f:
        data = np.fromfile(f).reshape((ndays, nrows, ncols))

    ncolours = 7
    vmin = 0
    vmax = 6

    cmap = sns.blend_palette(["white", "#1b7837"], ncolours, as_cmap=True)
    sns.set(style="white")

    fig = plt.figure(figsize=(10, 6))
    grid = AxesGrid(fig, [0.05, 0.05, 0.9, 0.9], nrows_ncols=(1, 1), axes_pad=0.1,
                    cbar_mode='single', cbar_pad=0.4, cbar_size="7%",
                    cbar_location='bottom', share_all=True)
    # Map corners: 111.975 + (841. * 0.05) = 154.025 (east edge) and
    # -44.025 + (681. * 0.05) = -9.975 (north edge).
    m = Basemap(projection='cyl', llcrnrlon=111.975, llcrnrlat=-44.025,
                urcrnrlon=154.025, urcrnrlat=-9.974999999999994, resolution='h')

    ax = grid[0]
    m.ax = ax

    m.readshapefile('/Users/mdekauwe/research/Drought_linkage/'
                    'Bios2_SWC_1979_2013/'
                    'AUS_shape/STE11aAust',
                    'STE11aAust', drawbounds=True)

    # Restrict the view to lon 140.5..154, lat -38..-28.
    ax.set_xlim(140.5, 154)
    ax.set_ylim(-38, -28)

    m.imshow(data[20, :, :], cmap,
             colors.Normalize(vmin=vmin, vmax=vmax, clip=True),
             origin='upper', interpolation='nearest')

    colorbar_index(cax=grid.cbar_axes[0], ncolours=ncolours, cmap=cmap,
                   orientation='horizontal', vmin=vmin, vmax=vmax)

    fig.savefig("/Users/mdekauwe/Desktop/LAI_NSW.png", bbox_inches='tight',
                pad_inches=0.1, dpi=300)

    plt.show()
def simulate_attractor_competition(Imax=12,
                                   I0=0.05,
                                   k=1.15,
                                   B=.6,
                                   g=15,
                                   b=30,
                                   rmax=100,
                                   si=6.5,
                                   dt=.002,
                                   tau=.075,
                                   Z=100,
                                   ntrials=250):
    """Plot (r1, r2) trajectories of ``attractor_network`` under biased inputs.

    The first half of the trials holds ``I1 = Imax`` while ``I2`` steps from
    ``Imax`` down to ``0.5 * Imax``; the second half swaps the roles of I1
    and I2. Each trajectory is coloured by its position in ``Ivector``
    (-1..1). All other parameters are forwarded unchanged to
    ``attractor_network``.

    :param ntrials: total number of trials; an odd value is rounded down to
                    an even count so both halves have equal length.
    """
    sns.set(style='white', font_scale=1.8)
    f, ax = plt.subplots(1, figsize=(8, 7))
    cmap = mpl.colors.ListedColormap(
        sns.blend_palette([clrs[1], clrs[0]], n_colors=ntrials))
    # np.linspace needs an integer sample count; `ntrials / 2` is a float in
    # Python 3 and raises TypeError on current NumPy.
    half = ntrials // 2
    Iscale = np.tile(np.linspace(.5 * Imax, Imax, half)[::-1], 2)
    Ivector = np.linspace(-1, 1, len(Iscale))
    norm = mpl.colors.Normalize(vmin=np.min(Ivector), vmax=np.max(Ivector))
    sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    for i, I_t in enumerate(Iscale):
        # Split on the actual half length so odd `ntrials` stay symmetric.
        if i < half:
            I1, I2 = Imax, I_t
        else:
            I1, I2 = I_t, Imax
        r1, r2, dv, rt = attractor_network(I1=I1,
                                           I2=I2,
                                           I0=I0,
                                           k=k,
                                           B=B,
                                           g=g,
                                           b=b,
                                           rmax=rmax,
                                           si=si,
                                           dt=dt,
                                           tau=tau,
                                           Z=Z)
        ax.plot(r1, r2, color=sm.to_rgba(Ivector[i]), alpha=.5)

    c_ax = plt.colorbar(sm, ax=plt.gca())
    c_ax.set_ticks([-1, 1])
    c_ax.set_ticklabels(['$I_1<<I_2$', '$I_1>>I_2$'])
    ax.plot([0, rmax], [0, rmax], color='k', alpha=.5, linestyle='-', lw=3.5)
    # NOTE(review): ax.plot(r1, r2) puts r1 on the x axis, yet the labels
    # below name the y axis r_1 and the x axis r_2 - confirm which is meant.
    plt.setp(ax,
             ylim=[0, rmax],
             xlim=[0, rmax],
             xticks=[0, rmax],
             xticklabels=[0, rmax],
             yticks=[0, rmax],
             yticklabels=[0, rmax],
             ylabel='$r_1$ (Hz)',
             xlabel='$r_2$ (Hz)')
Example #22
0
    def animate_nearest_neighbour(self, **kwargs):
        """Animate the joint/marginal KDE of values vs. their nearest neighbours.

        Draws the initial grid with ``self.vis.show_nearest_neighbour(**kwargs)``.
        For each frame ``i`` (one per entry of ``self.sn.result.t``), the
        joint and marginal axes are cleared. The KDEs of
        ``self.sn.result.y[:, i]`` against
        ``self.sn.get_nearest_neighbours(t=i)`` are then redrawn, with pairs
        containing NaN dropped. The animation is stored in
        ``self.animations["nearest_neighbour"]``; nothing is returned.

        Keyword arguments:
            bw_adjust: KDE bandwidth factor (default 0.5); also forced onto
                the marginal plots.
            color: base colour (default "Purple"). A 12-step light-to-dark
                colormap is derived from it for the joint plot, and it is the
                default marginal colour.
            cmap, shade: joint-plot defaults (derived colormap, True).
            marginal_kws: extra options for the marginal KDEs. The dict passed
                in is updated in place.
            Everything else is forwarded to the joint ``sns_kdeplot`` call.
        """
        # TODO: use custom jointplot
        # setdefault writes bw_adjust into kwargs, so the initial
        # show_nearest_neighbour call and the per-frame joint plot both see it.
        bw_adjust = kwargs.setdefault("bw_adjust", 0.5)
        # NOTE: kwargs still contains "color" (and "marginal_kws", if given)
        # at this point; both are popped only after this call.
        g = self.vis.show_nearest_neighbour(**kwargs)
        # NOTE(review): xlim/ylim are captured but never used below.
        xlim = g.ax_joint.get_xlim()
        ylim = g.ax_joint.get_ylim()
        color = kwargs.pop("color", "Purple")
        color_rgb = colorConverter.to_rgb(color)
        # 12 shades of the base colour, lightness 1 (white) down to 0 (black)
        colors = [
            sns.set_hls_values(color_rgb, l=l) for l in np.linspace(1, 0, 12)  # noqa
        ]
        # Make a colormap based off the plot color
        cmap = sns.blend_palette(colors, as_cmap=True)
        kwargs.setdefault("cmap", cmap)
        kwargs.setdefault("shade", True)
        marginal_kws = kwargs.pop("marginal_kws", dict())
        # Overrides any bw_adjust the caller put in marginal_kws.
        marginal_kws.update(bw_adjust=bw_adjust)
        marginal_kws.setdefault("color", color)
        marginal_kws.setdefault("shade", True)

        def init():
            # Blank frame shown before the first animate() call.
            g.ax_joint.clear()
            g.ax_marg_x.clear()
            g.ax_marg_y.clear()
            g.fig.suptitle(f"{0:>6.3f}")

        def animate(i):
            # Redraw the grid for time index i.
            g.ax_joint.clear()
            g.ax_marg_x.clear()
            g.ax_marg_y.clear()
            x = self.sn.result.y[:, i]
            y = self.sn.get_nearest_neighbours(t=i)
            # drop nans
            not_na = pd.notnull(x) & pd.notnull(y)
            g.x = x[not_na]
            g.y = y[not_na]
            g.plot_joint(sns_kdeplot, **kwargs)
            g.plot_marginals(sns_kdeplot, **marginal_kws)

            # these are reset after .clear(); so go correct these as in joint_plot
            plt.setp(g.ax_marg_x.get_xticklabels(), visible=False)
            plt.setp(g.ax_marg_y.get_yticklabels(), visible=False)
            plt.setp(g.ax_marg_x.yaxis.get_majorticklines(), visible=False)
            plt.setp(g.ax_marg_x.yaxis.get_minorticklines(), visible=False)
            plt.setp(g.ax_marg_y.xaxis.get_majorticklines(), visible=False)
            plt.setp(g.ax_marg_y.xaxis.get_minorticklines(), visible=False)
            plt.setp(g.ax_marg_x.get_yticklabels(), visible=False)
            plt.setp(g.ax_marg_y.get_xticklabels(), visible=False)
            # Title shows the simulation time of this frame.
            g.fig.suptitle(f"{self.sn.result.t[i]:>6.3f}", va="bottom")

        # Keep a reference on self so the animation is not garbage-collected.
        self.animations["nearest_neighbour"] = animation.FuncAnimation(
            g.fig,
            animate,
            init_func=init,
            frames=len(self.sn.result.t),
            repeat=False,
        )
Example #23
0
def show_heat(experiment, classes_dict, save=None):
    """Plot a similarity heatmap of the best model's entity embeddings.

    Entities listed in ``classes_dict`` are gathered class by class. Their
    rows are taken from the transposed ``best['parameters']['Eemb']['value']``
    matrix, and their dot-product similarity matrix is drawn. White lines
    mark the boundaries between classes.

    :param experiment: experiment record; its ``'best_on_validation'`` entry
                       provides ``entities`` and ``parameters``.
    :param classes_dict: mapping from class label to the entities in it.
    :param save: output path; when None the figure is shown instead.
    """
    show_experiment(experiment)

    best = experiment['best_on_validation']
    entities = best['entities']
    variable = 'Eemb'

    classes, elements = [], []
    for _class, _elements in classes_dict.items():
        for _element in _elements:
            classes.append(_class)
            elements.append(_element)

    indexes = [entities.index(element) for element in elements]
    logging.info('#Indexes: %d', len(indexes))

    sns.despine(trim=True)
    sns.set_context('poster')
    # NOTE(review): c_map is built but never passed to sns.heatmap below;
    # add cmap=c_map there if the red/green scheme was intended.
    c_map = sns.blend_palette(["firebrick", "palegreen"], as_cmap=True)

    sns.set_style('white')

    parameter = best['parameters'][variable]

    embeds = numpy.asarray(parameter['value'])
    print("Embedding shape", embeds.shape)

    X = embeds.T

    Xr = X[numpy.asarray(indexes), :]

    sim_mat = Xr.dot(Xr.T)
    frame = pandas.DataFrame(sim_mat)

    sns.heatmap(frame, linewidths=0, square=True, robust=True, xticklabels=False, yticklabels=False)

    # Separator lines wherever the class label changes between consecutive rows.
    for i, cl in enumerate(classes):
        if i > 0 and cl != classes[i - 1]:
            plt.axhline(len(classes) - i, c="w")
            plt.axvline(i, c="w")

    if save is None:
        # plt.show() takes no figure/axes argument (only `block`); passing the
        # Axes was interpreted as block=<Axes> or rejected outright.
        plt.show()
    else:
        plt.savefig(save)
def computeCorrelation(dataFrame, candidatesList,name):
    """Save a correlation-matrix plot of the ``candidatesList`` columns.

    The axes are titled ``"Correlation Matrix - <name>"``. The figure is
    written to ``Correlation_<first candidate>_.png``.
    """
    fig, axes = plt.subplots(figsize=(12, 12))
    subset = dataFrame[candidatesList]
    anchor_colors = ["#6B229F", "#FD3232", "#F66433", "#E78520", "#FFBB39"]
    cmap = sb.blend_palette(anchor_colors, as_cmap=True)
    sb.corrplot(subset, annot=False, sig_stars=False, diag_names=False,
                cmap=cmap)
    axes.set_title("Correlation Matrix - " + name)
    plt.savefig('Correlation_' + candidatesList[0] + '_.png')
Example #25
0
def correlations(data, X):
    """Save and report a correlation plot of ``data``, named after ``X``'s columns.

    The output path is ``visuals/<col1>_<col2>_..._correlation``. It is
    printed after saving, and the current figure is then closed.
    """
    out_path = 'visuals/' + "_".join(X.columns.tolist()) + '_correlation'
    _, axis = plt.subplots(figsize=(10, 10))
    diverging = sns.blend_palette(
        ["#00008B", "#6A5ACD", "#F0F8FF", "#FFE6F8", "#C71585", "#8B0000"],
        as_cmap=True)
    sns.corrplot(data, annot=False, diag_names=False, cmap=diverging)
    axis.grid(False)
    plt.savefig(out_path)
    print(out_path)
    plt.close()
Example #26
0
def save_adjastency_matrix(name: str, path: str, data) -> None:
    """Render ``data`` as a blue-yellow-red heatmap titled ``name`` and save it.

    Tick labels are shown every 100 rows/columns. The figure is closed after
    saving to ``path``.
    """
    # One figure only: previously the title went to whatever figure was
    # current before this call, and an unused 80x80-inch figure was created.
    fig, ax = plt.subplots(figsize=(11, 9))
    ax.set_title(name)
    colors = ["blue", "yellow", "red"]
    cmap = sns.blend_palette(colors, as_cmap=True)
    sns.heatmap(data, cmap=cmap, square=True, xticklabels=100, yticklabels=100,
                ax=ax)
    fig.savefig(path)
    # Close just this figure; the old trailing plt.clf() re-created an empty
    # figure right after closing everything.
    plt.close(fig)
def plot_pt_corr(df):
    """
    Plot the Spearman correlation matrix of the parameter posteriors in ``df``,
    annotated and with significance stars, on a 9x9 figure.
    """
    anchor_colors = ["#00008B", "#6A5ACD", "#F0F8FF",
                     "#FFE6F8", "#C71585", "#8B0000"]
    fig, axis = P.subplots(figsize=(9, 9))
    diverging = sns.blend_palette(anchor_colors, as_cmap=True)
    sns.corrplot(df, annot=True, sig_stars=True, method='spearman',
                 diag_names=True, cmap=diverging, ax=axis)
    fig.tight_layout()
Example #28
0
    def _parse_kwargs(self, **kwargs):
        """Populate plot options from keyword arguments, falling back to defaults.

        Defaults that depend on data sizes read ``self._dfc.shape`` and
        ``self._dfs.shape``, so both frames must be set before this is called.
        Row/column clustering settings are only stored when the matching
        ``*_cluster`` option is true.
        """
        opt = kwargs.get

        self._vmin = opt('vmin', None)
        self._vmax = opt('vmax', None)

        # red -> white -> opaque green, sampled at 100 steps
        fallback_cmap = sns.blend_palette(
            ['red', 'white', np.array([0, 1, 0, 1])], n_colors=100, as_cmap=True)
        self._cmap = opt('cmap', fallback_cmap)

        self._dim = opt('dim', (2, 5))
        self._wspace = opt('wspace', 0.005)
        self._hspace = opt('hspace', 0.005)

        c_rows, c_cols = self._dfc.shape[0], self._dfc.shape[1]
        s_cols = self._dfs.shape[1]
        fallback_widths = [c_cols, 0.25, s_cols, 0.5 * s_cols, 0.05 * s_cols]
        fallback_heights = [s_cols, c_rows]
        self._width_ratios = opt('width_ratios', fallback_widths)
        self._height_ratios = opt('height_ratios', fallback_heights)

        # (row, column) grid cell of each sub-axes
        fallback_position = {
            'condition': np.array([1, 0]),
            'heatmap': np.array([1, 2]),
            'row_dendrogram': np.array([1, 3]),
            'col_dendrogram': np.array([0, 2]),
            'colorbar': np.array([1, 4]),
        }
        self._axes_position = opt('axes_position', fallback_position)

        self._row_cluster = opt('row_cluster', True)
        self._col_cluster = opt('col_cluster', True)

        if self._row_cluster:
            self._row_method = opt('row_method', 'single')
            self._row_metric = opt('row_metric', 'cityblock')
            self._row_dend_linewidth = opt('row_dend_linewidth', 0.5)

        if self._col_cluster:
            self._col_method = opt('col_method', 'single')
            self._col_metric = opt('col_metric', 'cityblock')
            self._col_dend_linewidth = opt('col_dend_linewidth', 0.5)

        self._table_linewidth = opt('table_linewidth', 0.5)
        self._table_tick_fontsize = opt('table_tick_fontsize', 5)
        self._colorbar_tick_fontsize = opt('colorbar_tick_fontsize', 5)
Example #29
0
def visualize_correlations(training_data):
    """
    Draw a correlation-matrix heat map for the columns of ``training_data``.

    Cell annotations and diagonal variable names are enabled only when there
    are fewer than 30 columns, so wide matrices stay legible.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    colormap = sb.blend_palette(sb.color_palette('coolwarm'), as_cmap=True)
    show_labels = len(training_data.columns) < 30
    sb.corrplot(training_data, annot=show_labels, sig_stars=False,
                diag_names=show_labels, cmap=colormap, ax=ax)
    fig.tight_layout()
Example #30
0
def format_confusion_matrix(matrix,
                            classes,
                            figsize=(8, 4.5),
                            title='',
                            **kwargs):
    """
    Plot confusion matrix using provided sklearn metrics.confusion_matrix values
    :param matrix: Numpy array with values from sklearn metrics.confusion_matrix
    :param classes: Tuple for (negative, positive) class labels
    :param figsize: Figure size
    :param title: Figure title
    :param kwargs: Arguments to pass to the plot function
    :return: figure with confusion matrix plot
    """
    tn, fp, fn, tp = matrix.ravel()
    labels = np.array([[
        'TP = {}'.format(tp), 'FP = {}'.format(fp),
        'Precision = {:.2f}%'.format(tp / (tp + fp) * 100)
    ], ['FN = {}'.format(fn), 'TN = {}'.format(tn), ''],
                       [
                           'TPR = {:.2f}%'.format(tp / (tp + fn) * 100),
                           'FPR = {:.2f}%'.format(fp / (fp + tn) * 100), ''
                       ]])

    #sns.set_style("ticks", {"xtick.major.size": 0, "ytick.major.size": 0})
    #sns.set(font_scale=1.2)

    columns = ['Labelled ' + classes[1], 'Labelled ' + classes[0], '']
    index = ['Predicted ' + classes[1], 'Predicted ' + classes[0], '']
    vals = np.array([[tp, fp, 0], [fn, tn, 0], [0, 0, 0]])
    template = pd.DataFrame(vals, index=index, columns=columns)
    print(template)

    vmax = np.sum(vals)
    cmap = sns.blend_palette(['white', '#0066cc'], as_cmap=True)
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    sns.heatmap(template,
                ax=ax,
                annot=labels,
                fmt='',
                vmax=vmax,
                cbar=False,
                cmap=cmap,
                linewidths=1,
                **kwargs)
    ax.xaxis.tick_top()
    plt.suptitle(title, fontsize=13)
    plt.yticks(rotation=0)
    ax.tick_params(labelsize=15)
    fig.tight_layout()
    fig.subplots_adjust(top=0.77)

    return fig
Example #31
0
def get_cpals(name='all', aslist=False, random=False):
    """Return two-colour blended palette factories.

    Each factory takes a colour count ``nc`` and returns
    ``sns.blend_palette([start, end], n_colors=nc)``.

    :param name: palette key, or 'all' for every palette
    :param aslist: with name='all', return the factories as a list instead of a dict
    :param random: return one randomly chosen factory (overrides ``name``)
    :raises KeyError: if ``name`` is not a known palette key
    """
    # Imported under its own name because the ``random`` parameter shadows
    # the stdlib module inside this function.
    from random import choice

    def _pal(start, end):
        # Bind both endpoint colours now; the caller supplies n_colors later.
        return lambda nc: sns.blend_palette([start, end], n_colors=nc)

    color_dict = {
        'bpal': _pal('#81aedb', '#3572C6'),
        'gpal': _pal('#65b88f', '#27ae60'),
        'rpal': _pal('#e88379', '#de143d'),
        'ppal': _pal('#9B59B6', "#663399"),
        'heat': _pal('#f39c12', '#e5344a'),
        'cool': _pal("#4168B7", "#27ae60"),
        'slate': _pal('#95A5A6', "#6C7A89"),
        'wet': _pal('#34495e', "#99A4AE"),
        'fire': _pal('#e5344a', "#f39c12"),
        'bupu': _pal('#8E44AD', "#3498db"),
    }
    if random:
        # The previous randint(0, n, 1) call was invalid for stdlib randint and
        # produced an array (not a list index) with NumPy's randint.
        return choice(list(color_dict.values()))
    if name == 'all':
        if aslist:
            return list(color_dict.values())
        return color_dict
    return color_dict[name]
Example #32
0
def generate_continuous_palette(colors, n_colors: int = 256) -> List[str]:
    """Blend a sequence of hex colours into a continuous palette of hex colours.

    Args:
        colors (Sequence[str]): Hex colour strings to blend between, in order.
        n_colors (int, optional): Number of colours in the resulting palette.
            Defaults to 256.

    Returns:
        List[str]: ``n_colors`` hex colour strings forming the blended palette.
    """
    rgb_stops = [normalize_rgb(hex_to_rgb(hex_code)) for hex_code in colors]
    blended = sns.blend_palette(rgb_stops, n_colors=n_colors)
    hex_palette = []
    for rgb in blended:
        hex_palette.append(rgb_to_hex(denormalize_rgb(rgb)))
    return hex_palette
Example #33
0
def param_color_map(param='all'):
    """Return the plotting colour for a model parameter name.

    :param param: 'all' (default) returns the whole name -> hex mapping; a
        known name returns its hex colour; a name containing '_' whose
        components are all known (e.g. 'a_v') returns a blend of the component
        colours; any other name returns a random colour from ``assorted_list()``.
    :raises KeyError: if a '_'-joined name contains an unknown component.
    """
    colors = {'a': "#375ee1", 'tr': "#f19b2c", 'v': "#27ae60", 'xb': "#16a085", 'ssv': "#e5344a", 'ssv_v': "#3498db", 'sso': "#e941cd", 'z': '#ff711a', 'all': '#6C7A89', 'flat': '#6C7A89', 'BX': '#f19b2c', 'AX': '#3498db', 'Beta': "#ff711a", 'v_ssv': "#9B59B6", 'PX':"#e5344a"}
    if param == 'all':
        return colors
    # Exact names (including ones that contain '_') win over blending.
    if param in colors:
        return colors[param]
    if '_' in param:
        components = [colors[p] for p in param.split('_')]
        # Take an intermediate shade of a 6-step blend of the component colours.
        return sns.blend_palette(components, n_colors=6)[3]
    choices = assorted_list()
    return choices[np.random.randint(0, len(choices))]
 def __init__(self, large_df):
     """Create the 9x9 correlation-matrix figure and wire up mouse handlers.

     :param large_df: stored on ``self.large_df``; presumably the data whose
         correlations are drawn by ``on_draw`` — TODO confirm.
     """
     self.f, self.ax = plt.subplots(figsize=(9, 9))
     plt.tight_layout()
     # The base class is initialised with the figure; presumably a Qt
     # FigureCanvas subclass (it provides palette() and mpl_connect) — verify.
     super(CorrelationMatrixFigure, self).__init__(self.f)
     palette = self.palette()
     # Match the figure background to the widget palette's background colour
     # (RGB only; the alpha component is dropped by the [0:3] slice).
     self.f.set_facecolor(palette.background().color().getRgbF()[0:3])
     self.df = None
     self.corr = None
     # Diverging colormap: dark blue -> near-white -> dark red.
     self.cmap = sns.blend_palette(
         ["#00008B", "#6A5ACD", "#F0F8FF", "#FFE6F8", "#C71585", "#8B0000"],
         as_cmap=True)
     # Hover shows a tooltip; clicking a cell is handled by square_clicked.
     self.mpl_connect("motion_notify_event", self.get_tooltip_message)
     self.mpl_connect("button_press_event", self.square_clicked)
     self.large_df = large_df
     self.on_draw()
Example #35
0
def get_cpals(name='all'):
    """Return two-colour blended palette factories.

    Every factory maps a colour count ``nc`` to
    ``sns.blend_palette([start, end], n_colors=nc)``. With ``name='all'`` the
    full name -> factory dict is returned, otherwise the single factory for
    ``name`` (KeyError if unknown).
    """
    endpoints = (
        ('rpal', '#e88379', '#c0392b'),
        ('bpal', '#81aedb', '#3A539B'),
        ('gpal', '#65b88f', '#27ae60'),
        ('ppal', '#848bb6', "#9B59B6"),
        ('heat', '#f39c12', '#c0392b'),
        ('cool', "#4168B7", "#27ae60"),
        ('slate', '#95A5A6', "#6C7A89"),
    )

    def make_factory(start, end):
        return lambda nc: sns.blend_palette([start, end], n_colors=nc)

    palettes = {}
    for key, start, end in endpoints:
        palettes[key] = make_factory(start, end)
    return palettes if name == 'all' else palettes[name]
Example #36
0
def display_corr_matrix():
    '''
    Plot a correlation-matrix heat map of the module-level ``DF``.

    Intended for spotting collinearity between columns. The figure is titled
    'Figure 1: Correlation Matrix Heatmap' and shown interactively.
    '''
    # DF is only read here, so no ``global`` declaration is needed.
    data = DF
    sns.set(color_codes=True)
    f, ax = plt.subplots(figsize=(9, 9))
    cmap = sns.blend_palette(["#00008B", "#6A5ACD", "#F0F8FF",
                              "#FFE6F8", "#C71585", "#8B0000"], as_cmap=True)
    sns.corrplot(data, annot=False, sig_stars=False,
                 diag_names=False, cmap=cmap, ax=ax)
    # Call pyplot directly: newer seaborn releases no longer expose it as sns.plt.
    plt.title('Figure 1: Correlation Matrix Heatmap')
    f.tight_layout()
    sns.despine()
    plt.show()
Example #37
0
    def __convert_graph(self, g):
        """Serialise networkx graph ``g`` into a dict holding only 'nodes' and 'edges'.

        Each node gets a random (x, y) position, a colour sampled at random
        from a mediumseagreen-ghostwhite-blue colormap, size 0.5 and an id
        ``"n<id>"``. Each edge gets an id ``"e<source>=<target>"``, and its
        endpoints are rewritten to the matching node ids.
        """
        from networkx.readwrite import json_graph
        data = json_graph.node_link_data(g)
        # node_link_data stores edges under 'links'; rename that key to 'edges'.
        data['edges'] = data.pop('links')
        cmap = sns.blend_palette(["mediumseagreen", "ghostwhite", "#4168B7"], 9, as_cmap=True)
        for node in data['nodes']:
            node['x'] = np.random.random()
            node['y'] = np.random.random()
            node['color'] = rgb2hex(cmap(np.random.random()))
            node['size'] = 0.5
            node['id'] = "n%d" % node['id']

        for edge in data["edges"]:
            src, dst = edge['source'], edge['target']
            edge['id'] = 'e%d=%d' % (src, dst)
            edge['source'] = "n%d" % src
            edge['target'] = "n%d" % dst

        return {k: v for k, v in data.items() if k in ('nodes', 'edges')}
Example #38
0
def scatterPlot(transformed):
    """Scatter columns ``x``/``y`` of ``transformed`` coloured by ``cluster``, labelled by ``pos``.

    Axes are labelled PC1/PC2, y is limited to (-2, 2), overlapping text
    labels are spread out with ``adjust_text``, and the plot is shown.
    """
    sns.set(style="white")
    palette = sns.blend_palette(vapeplot.palette('vaporwave'))

    grid = sns.lmplot(x="x", y="y", hue='cluster', data=transformed, legend=False,
                      fit_reg=False, height=8, scatter_kws={"s": 25}, palette=palette)

    labels = [plt.text(px, py, label)
              for px, py, label in zip(transformed.x, transformed.y, transformed.pos)]
    # Pass arrowprops=dict(arrowstyle="->", color='r', lw=0.5) to draw arrows to labels.
    adjust_text(labels)

    grid.set(ylim=(-2, 2))
    plt.tick_params(labelsize=15)
    plt.xlabel('PC1', fontsize=20)
    plt.ylabel("PC2", fontsize=20)
    plt.show()
Example #39
0
    def make_figure(self):
        """Show this entry's oracle correlation map next to its log p-value map.

        Fetches the row via ``self.fetch1()`` and draws ``oracle_map`` (colour
        range fixed to [-1, 1], with a colorbar) and ``log(p_map / p_map.size)``.
        Tall maps are placed side by side with colorbar slots underneath;
        wide maps are stacked with colorbar slots on the right. The figure is
        shown and then closed.
        """
        import seaborn as sns
        import matplotlib.pyplot as plt

        key = self.fetch1()
        m, p = key["oracle_map"], key["p_map"]
        cmap = sns.blend_palette(
            ["dodgerblue", "steelblue", "k", "lime", "yellow"], as_cmap=True)
        with sns.axes_style("white"):
            fig = plt.figure(figsize=(15, 15))
            if m.shape[0] > m.shape[1]:
                orientation = "horizontal"
                gs = plt.GridSpec(21, 2)
                ax_corr = fig.add_subplot(gs[:-1, 0])
                cax_corr = fig.add_subplot(gs[-1, 0])
                ax_p = fig.add_subplot(gs[:-1, 1])
                fig.add_subplot(gs[-1, 1])  # slot for the (disabled) p-value colorbar
            else:
                orientation = "vertical"
                gs = plt.GridSpec(2, 21)
                ax_corr = fig.add_subplot(gs[0, :-1])
                cax_corr = fig.add_subplot(gs[0, -1])
                ax_p = fig.add_subplot(gs[1, :-1])
                fig.add_subplot(gs[1, -1])  # slot for the (disabled) p-value colorbar

            h = ax_corr.imshow(m, vmin=-1, vmax=1, cmap=cmap)
            fig.colorbar(h, cax=cax_corr, orientation=orientation)

            ax_p.matshow(np.log(p / p.size), cmap="coolwarm_r")
            for a in (ax_corr, ax_p):
                a.axis("off")
            fig.tight_layout()
            fig.subplots_adjust(top=0.9)
            ax_corr.set_title("oracle correlation map")
            ax_p.set_title("log p-value (incorrect DF)")
            fig.suptitle(
                "{animal_id}-{session}-{scan_idx} field {field}".format(**key))
            plt.show()
            plt.close(fig)
Example #40
0
def visualize_correlations(training_data):
    """
    Render a correlation heat map of every column in ``training_data``.

    Frames with fewer than 30 columns get annotated cells and diagonal
    names; wider frames are drawn without them.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    colormap = sb.blend_palette(sb.color_palette('coolwarm'), as_cmap=True)
    narrow = len(training_data.columns) < 30
    common = dict(sig_stars=False, cmap=colormap, ax=ax)
    sb.corrplot(training_data, annot=narrow, diag_names=narrow, **common)
    fig.tight_layout()
def _custom_palettes():
    """Return extra named colour maps keyed by display name.

    String values are matplotlib colormap names (mostly reversed ``_r``
    variants). The other values are seaborn-built colormap objects:
    cubehelix ramps, one light palette, and blends of xkcd colour names.
    """
    def helix(**params):
        # Every cubehelix entry is requested as a continuous colormap.
        return sns.cubehelix_palette(as_cmap=True, **params)

    def xkcd_blend(*names):
        return sns.blend_palette(list(names), input='xkcd', as_cmap=True)

    palettes = {
        'YellowOrangeBrown': 'YlOrBr',
        'YellowOrangeRed': 'YlOrRd',
        'OrangeRed': 'OrRd',
        'PurpleRed': 'PuRd',
        'RedPurple': 'RdPu',
        'BluePurple': 'BuPu',
        'GreenBlue': 'GnBu',
        'PurpleBlue': 'PuBu',
        'YellowGreen': 'YlGn',
        'summer': 'summer_r',
        'copper': 'copper_r',
        'viridis': 'viridis_r',
        'plasma': 'plasma_r',
        'inferno': 'inferno_r',
        'magma': 'magma_r',
    }
    palettes['sirocco'] = helix(dark=0.15, light=0.95)
    palettes['drifting'] = helix(start=5, rot=0.4, hue=0.8)
    palettes['melancholy'] = helix(start=25, rot=0.4, hue=0.8)
    palettes['enigma'] = helix(start=2, rot=0.6, gamma=2.0, hue=0.7, dark=0.45)
    palettes['eros'] = helix(start=0, rot=0.4, gamma=2.0, hue=2,
                             light=0.95, dark=0.5)
    palettes['spectre'] = helix(start=1.2, rot=0.4, gamma=2.0, hue=1, dark=0.4)
    palettes['ambition'] = helix(start=2, rot=0.9, gamma=3.0, hue=2,
                                 light=0.9, dark=0.5)
    palettes['mysteriousstains'] = sns.light_palette(
        'baby shit green', input='xkcd', as_cmap=True)
    palettes['daydream'] = xkcd_blend('egg shell', 'dandelion')
    palettes['solano'] = xkcd_blend('pale gold', 'burnt umber')
    palettes['navarro'] = xkcd_blend('pale gold', 'sienna', 'pine green')
    palettes['dandelions'] = xkcd_blend('sage', 'dandelion')
    palettes['deepblue'] = xkcd_blend('really light blue', 'petrol')
    palettes['verve'] = helix(start=1.4, rot=0.8, gamma=2.0, hue=1.5, dark=0.4)
    palettes['greyscale'] = xkcd_blend('light grey', 'dark grey')
    return palettes
Example #42
0
    def plot_NMF_ROIs(self, outdir='./'):
        """Save one figure per NMF-segmented cell: calcium trace, spike trace and mask contour.

        For every matching scan/slice key, each cell's mask contour is drawn
        over the mean normalised template. The figure is written to
        ``<outdir>/scan_idx<N>/slice<M>/cell<NNN>_animal_id_<A>_session_<S>.png``.
        A leading ``~`` in ``outdir`` is expanded.
        """
        sns.set_context('paper')
        theCM = sns.blend_palette(['lime', 'gold', 'deeppink'], n_colors=10)  # plt.cm.RdBu_r
        # Expand "~" once and use it for both directory creation and saving,
        # so figures land in the directory that was actually created.
        base_dir = os.path.expanduser(outdir)

        for key in (self.project() * SegmentMethod()*SpikeInference() & dict(short_name='stm', method_name='nmf')).fetch.as_dict:
            mask_px, mask_w, ca, sp = (SegmentMask()*Trace()*Spikes() & key).fetch.order_by('mask_id')['mask_pixels', 'mask_weights', 'ca_trace', 'spike_trace']

            template = np.stack([normalize(bugfix_reshape(t)[..., key['slice']-1].squeeze())
                                 for t in (ScanCheck() & key).fetch['template']], axis=2).mean(axis=2) # TODO: remove bugfix_reshape once djbug #191 is fixed

            d1, d2 = tuple(map(int, (ScanInfo() & key).fetch1['px_height', 'px_width']))
            masks = Segment.reshape_masks(mask_px, mask_w, d1, d2)
            gs = plt.GridSpec(6,1)
            slice_dir = base_dir + '/scan_idx{scan_idx}/slice{slice}'.format(**key)
            try:
                os.makedirs(slice_dir)
            except OSError:
                # An existing directory is fine; any other failure is real.
                if not os.path.isdir(slice_dir):
                    raise
            for cell, (ca_trace, sp_trace) in enumerate(zip(ca, sp)):
                with sns.axes_style('white'):
                    fig = plt.figure(figsize=(6,8))
                    ax_image = fig.add_subplot(gs[2:,:])

                with sns.axes_style('ticks'):
                    ax_ca = fig.add_subplot(gs[0,:])
                    ax_sp = fig.add_subplot(gs[1,:], sharex=ax_ca)
                ax_ca.plot(ca_trace,'green', lw=1)
                ax_sp.plot(sp_trace,'k',lw=1)

                ax_image.imshow(template, cmap=plt.cm.gray)
                ax_image.contour(masks[..., cell], colors=theCM, zorder=10)
                sns.despine(ax=ax_ca)
                sns.despine(ax=ax_sp)
                ax_ca.axis('tight')
                ax_sp.axis('tight')
                fig.suptitle("animal_id {animal_id}:session {session}:scan_idx {scan_idx}:{method_name}:slice{slice}:cell{cell}".format(cell=cell+1, **key))
                fig.tight_layout()

                plt.savefig(base_dir + "/scan_idx{scan_idx}/slice{slice}/cell{cell:03d}_animal_id_{animal_id}_session_{session}.png".format(cell=cell+1, **key))
                plt.close(fig)
    def __convert_graph(self, g):
        """Convert graph ``g`` to a {'nodes': ..., 'edges': ...} dict with random layout.

        Nodes receive random x/y coordinates, a random colour from a
        green-white-blue colormap, size 0.5 and id ``"n<id>"``. Edges receive
        id ``"e<source>=<target>"``, and their endpoints are renamed to node ids.
        """
        from networkx.readwrite import json_graph
        new_value = json_graph.node_link_data(g)
        new_value['edges'] = new_value.pop('links')
        colormap = sns.blend_palette(["mediumseagreen", "ghostwhite", "#4168B7"],
                                     9,
                                     as_cmap=True)
        rand = np.random.random
        for node in new_value['nodes']:
            x_pos = rand()
            y_pos = rand()
            shade = rand()
            node.update(x=x_pos, y=y_pos, color=rgb2hex(colormap(shade)), size=0.5)
            node['id'] = "n%d" % node['id']

        for edge in new_value["edges"]:
            source, target = edge['source'], edge['target']
            edge['id'] = 'e%d=%d' % (source, target)
            edge['source'] = "n%d" % source
            edge['target'] = "n%d" % target

        wanted = ('nodes', 'edges')
        return {key: val for key, val in new_value.items() if key in wanted}
Example #44
0
    def feature_correlations(self,
                             color_palette='coolwarm',
                             max_features=None,
                             annotate=False):
        """
        Generates a correlation matrix heat map (lower triangle only).

        Parameters
        ----------
        color_palette : string, optional, default 'coolwarm'
            Seaborn color palette.

        max_features : int, optional, default None
            The maximum number of columns in the data to plot (the first
            ``max_features`` columns are used).

        annotate : boolean, optional, default False
            Annotate the heat map with labels (values rounded to 2 decimals).
        """
        self.print_message('Generating feature correlations plot...')

        if max_features:
            corr = self.data.iloc[:, :max_features].corr()
        else:
            corr = self.data.corr()

        if annotate:
            corr = np.round(corr, 2)

        # Hide the diagonal and upper triangle. Plain ``bool`` is used because
        # the ``np.bool`` alias was removed in NumPy 1.24.
        mask = np.zeros_like(corr, dtype=bool)
        mask[np.triu_indices_from(mask)] = True

        fig, ax = plt.subplots(figsize=(self.fig_size, self.fig_size * 3 / 4))
        colormap = sb.blend_palette(sb.color_palette(color_palette),
                                    as_cmap=True)
        sb.heatmap(corr, mask=mask, cmap=colormap, annot=annotate, ax=ax)
        fig.tight_layout()

        self.print_message('Plot generation complete.')
Example #45
0
def plot_length_heamap(output_dir, database, interval):
    """Save a per-locus allele-length distribution heat map to allele_length_heatmap.png.

    Allele lengths are binned into ``interval``-wide buckets. Per locus, the
    pair counts in each bucket are converted to row percentages. Loci are
    ordered by descending occurrence, and the first 100 loci x 80 buckets are
    drawn with annotations. The figure is closed after saving.
    """
    output_file = os.path.join(output_dir, "allele_length_heatmap.png")
    allele_info = get_allele_info(database)
    # Label each length with the upper edge of its bucket; exact multiples
    # of ``interval`` go to the next bucket up.
    allele_info["intervals"] = list(
        map(lambda x: (int(x / interval) + 1) * interval,
            allele_info["length"]))
    pairs = db.from_sql("select * from pairs;", database=database)
    collect = []
    for locus_id, df in pairs.groupby("locus_id"):
        df2 = pd.merge(df, allele_info, on="allele_id", how="left")
        series = df2.groupby("intervals")["count"].sum()
        series.name = locus_id
        collect.append(series)
    table = pd.concat(collect, axis=1).fillna(0).T
    # Row-normalise to percentages of each locus' total count.
    table = table.apply(lambda x: 100 * x / np.sum(x), axis=1)

    # sort by scheme order
    freq = db.from_sql("select locus_id from loci order by occurrence DESC;",
                       database=database)
    table = pd.merge(freq, table, left_on="locus_id",
                     right_index=True).set_index("locus_id")

    table = table.apply(mask_by_length, axis=1).apply(np.floor, axis=1)
    to_show = table.iloc[0:100, 0:80]

    # plot
    fig, ax = plt.subplots(figsize=(24, 16))
    sns.heatmap(to_show,
                annot=True,
                annot_kws={},
                mask=to_show.isnull(),
                cmap=sns.blend_palette(["#446e8c", "#f6ff6d"],
                                       n_colors=20,
                                       as_cmap=True),
                ax=ax)
    plt.xticks(rotation=30)
    plt.yticks(rotation=0)
    plt.savefig(output_file)
    # Release the large figure; repeated calls would otherwise accumulate
    # open figures.
    plt.close(fig)
def display_grid_search_score_heatmaps():
    """
    Display 8 heatmaps for each of 8 values of alpha in the grid search.
    Each heatmap displays the 64 fitness scores for the simulations with the
    8 values each for gamma and epsilon. 
    """
    scored_results = score_grid_search_results()
    scores = [result['score'] for key, result in scored_results.iteritems()]
    max_score = max(scores)
    min_score = min(scores)
    ax = sns.plt.figure(figsize=(16, 25))
    plot_title = "Grid Search Fitness Scores\nMin: {}   |   Max: {}".format(round(min_score, 3), round(max_score, 3))
    sns.plt.suptitle(plot_title)
    cbar_ax = ax.add_axes([0.05, 0.92, 0.935, 0.02])
    for i in range(0,len(search_values)):
        alpha = search_values[i]
        score_grid = pd.DataFrame(columns=search_values)
        for gamma in reversed(search_values):
            gamma_scores = []
            for epsilon in search_values:
                key = "a:{},g:{},e:{}".format(alpha, gamma, epsilon)
                gamma_scores.append(scored_results[key]['score'])
            gamma_df = pd.DataFrame([gamma_scores], index=[gamma], columns=search_values)
            score_grid = score_grid.append(gamma_df)
        sns.plt.subplot(4,2,i+1)
        ax = sns.heatmap(score_grid,
                         annot=True,
                         cmap=sns.blend_palette(['#D43500','#FFFFFF', '#005C28'], as_cmap=True),
                         cbar=i == 0,
                         vmin=min_score, vmax=max_score,
                         cbar_ax=None if i else cbar_ax,
                         cbar_kws={"orientation": "horizontal"})
        ax.set(xlabel='Epsilon', ylabel='Gamma')
        ax.set_title('Alpha {}'.format(alpha))
    sns.plt.tight_layout()
    sns.plt.subplots_adjust(top=0.88)
    sns.plt.show()
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="darkgrid")

# Reproducible 100 x 30 array of normal random draws.
rs = np.random.RandomState(33)
d = rs.normal(size=(100, 30))

f, ax = plt.subplots(figsize=(9, 9))
# Diverging colormap: dark blue -> near-white -> dark red.
cmap = sns.blend_palette(["#00008B", "#6A5ACD", "#F0F8FF",
                          "#FFE6F8", "#C71585", "#8B0000"], as_cmap=True)

sns.corrplot(d, annot=False, sig_stars=False, diag_names=False, cmap=cmap, ax=ax)

f.tight_layout()
# <codecell>

# ``drop_index`` is not a to_csv parameter (current pandas raises TypeError
# for it), so it is omitted; the group-key index is written to the CSV, as
# presumably happened when older pandas ignored the unknown keyword.
Produktbereiche.sum().to_csv('Haushaltsentwurf2015-Produktbereiche.csv', float_format='%.2f')

# <headingcell level=2>

# Amtsbudgets

# <headingcell level=3>

# Barplots

# <codecell>

for i, ((Produktbereichsbezeichnung, Produktbezeichnung), group) in enumerate(Produktbereiche):
    colors = sns.blend_palette(["red", "mediumseagreen"], len(group['Ansatz']))

    # DataFrame.sort was removed in pandas 0.20; sort_values is its replacement.
    Amtsa = group.sort_values('Ansatz')
    Amtsa.index = [s[:80] for s in Amtsa.Amtsbudget]  # shorten overly long names

    # Decode only byte strings (Python 2 str); text strings are used as-is.
    if isinstance(Produktbezeichnung, bytes):
        plot_title = Produktbezeichnung.decode('utf-8')
    else:
        plot_title = Produktbezeichnung
    Amtsa['Ansatz'].plot(kind='barh', figsize=(10, 0.2*len(Amtsa)), label='', color=colors, title=plot_title)

    plt.tight_layout()
    plt.savefig('Amtsbudgets-%02d-%s.png' % (i, strcleaner(Produktbezeichnung)), transparent=True)
    plt.close()

# <headingcell level=3>

# Pie Plots

# <codecell>
Example #49
0
def _grey_color(color):
    """Return a greyed-out version of ``color``.

    Blends ``color`` toward ``flatgrey`` in 6 steps and returns the fifth
    step (index 4), i.e. a shade much nearer ``flatgrey`` than ``color``.
    """
    steps = sns.color_palette(sns.blend_palette([color, flatgrey], 6))
    return steps[4]
"""
Plotting a large correlation matrix
===================================

_thumb: .3, .6
"""
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="darkgrid")

# Reproducible 100 x 30 array of normal random draws.
rs = np.random.RandomState(33)
d = rs.normal(size=(100, 30))

f, ax = plt.subplots(figsize=(9, 9))
# Diverging colormap: dark blue -> near-white -> dark red.
cmap = sns.blend_palette(["#00008B", "#6A5ACD", "#F0F8FF",
                          "#FFE6F8", "#C71585", "#8B0000"], as_cmap=True)
# NOTE(review): corrplot only exists in old seaborn releases (it was
# deprecated in 0.6) — confirm the pinned seaborn version.
sns.corrplot(d, annot=False, sig_stars=False,
             diag_names=False, cmap=cmap, ax=ax)
f.tight_layout()
Example #51
0
def plot_wprp(actual_xis, actual_cov, pred_xis, pred_cov, set_desc, num_splits):
    """
    Plot calculated wp(rp) values with error bars, plus a power-law fit per group.

    Figure 1 overlays, for every group, the actual correlation function
    (solid) and the predicted one (dashed, alpha 0.6). Error bars are the
    square root of each covariance matrix's diagonal, and each group gets a
    colour on a red-to-blue ramp. The ``chisquare`` result for each group is
    printed.

    Figure 2 fits ``fixed_power_law`` to each of the first ``num_splits + 1``
    predicted curves over ``r > 2`` (left panel). It then bar-plots the second
    fitted parameter of each fit (right panel). Figure 2 is saved to
    ``image_prefix + title + png``, and all open figures are shown.
    """
    n_groups = len(actual_xis)
    # create a range of colors from red to blue
    colors = sns.blend_palette([red_col, blue_col], n_groups)

    plt.figure()
    ax = plt.gca()
    ax.set_xscale("log")
    ax.set_yscale('log')

    # range() and single-string print() calls keep this block valid, with
    # identical output, on both Python 2 and Python 3.
    for i, xi_pred, cov_pred, xi_actual, cov_actual in \
            zip(range(n_groups), pred_xis, pred_cov, actual_xis, actual_cov):

        print('%sth bin' % i)
        print('chi square is: %s' % (chisquare(xi_pred, xi_actual),))
        var1 = np.sqrt(np.diag(cov_pred))
        var2 = np.sqrt(np.diag(cov_actual))
        plt.errorbar(r, xi_actual, var2, fmt='-o', label=str(i+1), color=colors[i])
        plt.errorbar(r, xi_pred, var1, fmt='--o', color=colors[i], alpha=0.6)

    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.tick_params(pad=20)
    title = 'wp(rp) for ' + set_desc
    plt.xlabel('$r$ $[Mpc$ $h^{-1}]$')
    plt.xlim(1e-1, 30)
    plt.ylabel('$w_p(r_p)$')

    # Fits power laws of the form c(x^-1.5) + y0

    plt.figure()
    plt.subplot(121)
    # No plt.hold(True): overplotting is matplotlib's default behaviour and
    # hold() no longer exists in matplotlib >= 3.0.
    ax = plt.gca()
    ax.set_xscale("log")
    ax.set_yscale('log')
    fit_region, = np.where(r > 2)
    r_fit = r[fit_region]
    normalizations = []
    for i, xi_pred in zip(range(num_splits + 1), pred_xis):
        popt, pcov = curve_fit(fixed_power_law, r_fit, xi_pred[fit_region],
                               p0=[500, 20])
        normalizations.append(popt[1])
        plt.plot(r_fit, fixed_power_law(r_fit, *popt), color=colors[i], label=str(i+1))
    plt.legend()

    plt.subplot(122)
    # Keyword x/y: newer seaborn versions reject positional data arguments.
    sns.barplot(x=np.arange(1, num_splits + 2), y=np.array(normalizations), palette=colors)

    plt.savefig(image_prefix + title + png)
    plt.show()
Example #52
0
import argparse
from collections import OrderedDict

import matplotlib.pyplot as plt
import seaborn as sns

from aod_cells.schemata import *

# plot_params = dict(cmap=plt.cm.gray, vmin=0, vmax=1)
# (kept for reference: a variant that pins the intensity range to 0..1)
# Drawing keyword arguments for the base image, presumably the volume X (grayscale).
plot_params = dict(cmap=plt.cm.gray)
# Drawing keyword arguments for the overlay, presumably P: a yellow-to-deeppink
# colormap drawn above the base image (zorder=5) — TODO confirm usage.
plot_paramsP = dict(cmap=sns.blend_palette(["yellow", "deeppink"], as_cmap=True), zorder=5)


class CellLabeler:
    def __init__(self, X, cells=None, P=None):
        self.X = X
        self.cells = cells
        self.cell_idx = 0 if cells is not None else None

        self.cut = OrderedDict(zip(["row", "col", "depth"], [0, 0, 0]))

        self.P = 0 * self.X
        if P is not None:
            i, j, k = [(i - j + 1) // 2 for i, j in zip(self.X.shape, P.shape)]
            self.P[i:-i, j:-j, k:-k] = P

        fig = plt.figure(facecolor="w")
        gs = plt.GridSpec(3, 5)
        ax = dict()
        ax["depth"] = fig.add_subplot(gs[1:3, :2])
        ax["row"] = fig.add_subplot(gs[0, :2], sharex=ax["depth"])
Example #53
0
def blend_palette(color1, color2, ncolors=10):
    """Return ``ncolors`` hex colour strings blended from ``color1`` to ``color2``."""
    import seaborn as sns
    return list(map(mpl.colors.rgb2hex, sns.blend_palette((color1, color2), ncolors)))
Example #54
0
	mean3	IRRELEVANT	stderror3
	mean4	IRRELEVANT	stderror4
	mean5	IRRELEVANT	stderror5

Currently requires 5 primer pairs in the normal orientation; see further down in the code if the orientation differs.
'''
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from coord import * #coord contains the coordinates of primers of used genes
from scipy.interpolate import spline
from seaborn import blend_palette,desaturate,color_palette
# Global variable definitions
primers=('5\'', '.','..','...','3\'')
color_scheme=blend_palette([desaturate("#009B76", 0), "#009B76"], 5)

line_color=color_palette("hls", 8)
os.chdir('/Users/Luis/Desktop')

# Global definitions from the command line:
#   argument 1 - file name; the part before the first dot is the gene name
#   argument 2 - string with the IP used
#   argument 3 - string with the different conditions separated by whitespace
file_name=sys.argv[1]
Condition=sys.argv[2]
Conditions=sys.argv[3].split()
Gene=file_name.split('.')[0]

# Look the gene's coordinates up by name among the module-level variables
# (``from coord import *`` defines one per gene). A namespace lookup replaces
# eval(), which would execute any code embedded in the command-line file name.
if Gene not in globals():
    sys.exit('No primer coordinates defined in coord for gene %r' % Gene)
coordinates=globals()[Gene]
Example #55
0
  def make_violin(self):
    """
    Violin-plot the normalised photo-z residual per spectroscopic-redshift bin.

    A random subsample of objects is binned into 0.1-wide bins of measured
    redshift. Each violin shows the distribution of
    (predicted - measured) / (1 + measured) within one bin; bins with fewer
    than two points are skipped. The plot is saved as
    PHOTZ_VIOLIN_<family_name>.pdf.
    """
    # (The former unused ``from matplotlib.mlab import griddata`` was dropped:
    # griddata no longer exists in matplotlib >= 3.1, so the import itself failed.)
    import matplotlib.pyplot as plt
    import seaborn as sns

    self.logger.info("Generating violin plot...")
    ind = range(len(self.outliers))
    # Sample 10000 indices with replacement, then de-duplicate.
    rows = list(set(np.random.choice(ind, 10000)))
    # NOTE(review): the message reports self.reduce_size, but the sample size
    # above is hard-coded to 10000 — confirm which is intended.
    self.logger.info("Using a smaller size for space ({0} objects)".format(self.reduce_size))

    measured = self.measured[rows]
    predicted = self.predicted[rows]

    plt.figure()

    bins = np.arange(0, self.measured.max() + 0.1, 0.1)
    digitized = np.digitize(measured, bins)

    outliers2 = (predicted - measured) / (measured + 1)

    violins = [outliers2[digitized == i] for i in range(1, len(bins))]
    # Shift edges to bin centres for the violin names.
    dbin = (bins[1] - bins[0]) / 2.
    bins += dbin

    # Keep only bins with more than one point.
    kept = [(violin, centre) for violin, centre in zip(violins, bins) if len(violin) > 1]
    final_violin = [violin for violin, _ in kept]
    final_names = [centre for _, centre in kept]

    pal = sns.blend_palette([self.color_palette, "lightblue"], 4)

    sns.offset_spines()
    ax = sns.violinplot(final_violin, names=final_names, color=pal)
    sns.despine(trim=True)

    ax.set_ylabel(r"$(z_{\rm phot}-z_{\rm spec})/(1+z_{\rm spec})$", fontsize=self.fontsize)
    ax.set_xlabel(r"$z_{\rm spec}$", fontsize=self.fontsize)
    ax.set_ylim([-0.5, 0.5])

    # Label only every other x tick to avoid crowding.
    xtix = [i.get_text() for i in ax.get_xticklabels()]
    new_xtix = [xtix[i] if (i % 2 == 0) else "" for i in range(len(xtix))]
    ax.set_xticklabels(new_xtix)

    for item in ([ax.xaxis.label, ax.yaxis.label]):
      item.set_fontsize(self.fontsize)

    for item in (ax.get_xticklabels() + ax.get_yticklabels()):
      item.set_fontsize(self.fontsize - 10)

    ax.set_position([.15, .17, .75, .75])

    self.kde_ax = ax
    plt.savefig("PHOTZ_VIOLIN_{0}.pdf".format(self.family_name), format="pdf")
fig = plt.figure(figsize=(14, 10))
grid = AxesGrid(fig, [0.05,0.05,0.9,0.9], nrows_ncols=(1,1), axes_pad=0.1,
                cbar_mode='single', cbar_pad=0.2, cbar_size="3%",
                cbar_location='bottom', share_all=True)

m = Basemap(projection='cyl', llcrnrlon=-180.0, llcrnrlat=-90.0,
            urcrnrlon=180, urcrnrlat=90.0, resolution='c')



# Range on colourbar
ncolours = 11
vmin = 0.0
vmax = 1.0

bmap = sns.blend_palette(["white", "darkgreen"], ncolours, as_cmap=True)
ax = grid[0]
m.ax = ax
m.drawcoastlines(linewidth=0.1, color='k')
m.drawcountries(linewidth=0.1, color='k')
image = m.imshow(np.flipud(data), bmap,
                 colors.Normalize(vmin=vmin, vmax=vmax, clip=True),
                 interpolation='nearest')
cbar = colorbar_index(cax=grid.cbar_axes[0], ncolours=ncolours, cmap=bmap,
                      orientation='horizontal')
# Format tick values explicitly: raw linspace values carry float noise
# (e.g. 0.30000000000000004) that would otherwise appear in the labels.
cbar.set_ticklabels(['{:g}'.format(v) for v in np.linspace(vmin, vmax, ncolours)])
cbar.set_label("C4 Fraction (-)", fontsize=16)

# fluxnet sites
x1, y1 = m(x, y)
m.scatter(x1, y1, marker="o", color="black", alpha=0.7)
Example #57
0
def _plot_component_pair(X, y, class_count, first, title):
    """Scatter principal components ``first`` and ``first + 1`` in a new figure.

    Each class gets its own colour and legend entry. Class labels are assumed
    to run 1..class_count (i.e. 1-based).
    """
    class_colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
    fig, ax = plt.subplots(figsize=(16, 10))
    for i in range(class_count):
        class_idx = i + 1  # add 1 if class labels start at 1 instead of 0
        in_class = y == class_idx
        # Cycle the colour list so more than eight classes no longer raise IndexError.
        color = class_colors[i % len(class_colors)]
        ax.scatter(X[in_class, first], X[in_class, first + 1], c=color, label=class_idx)
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()


def visualize(training_data, X, y, pca):
    """
    Computes statistics describing the data and creates some visualizations
    that attempt to highlight the underlying structure.

    Parameters
    ----------
    training_data : DataFrame
        Column 0 is plotted from ``y``; column k > 0 is plotted from
        ``X[:, k - 1]`` (presumably target first, then features — confirm
        against the caller).
    X : 2-D array
        Feature matrix, indexed as ``X[:, k]``.
    y : 1-D array
        Class labels; the PCA plots assume labels start at 1.
    pca : object with ``transform`` or None
        When not None, ``X`` is projected and the first four components are
        plotted pairwise.

    Note: Use '%matplotlib inline' and '%matplotlib qt' at the IPython console
    to switch between display modes.
    """

    print('Generating individual feature histograms...')
    num_features = len(training_data.columns)
    # Ceiling division into pages of 16 (4x4) panels. Integer division is
    # required: `/` yields a float in Python 3, which range() rejects.
    num_plots = -(-num_features // 16)
    for i in range(num_plots):
        fig, ax = plt.subplots(4, 4, figsize=(20, 10))
        for j in range(16):
            index = (i * 16) + j
            if index >= num_features:
                break
            cell = ax[j // 4, j % 4]
            values = y if index == 0 else X[:, index - 1]
            cell.hist(values, bins=30)
            cell.set_title(training_data.columns[index])
            cell.set_xlim((min(values), max(values)))
        fig.tight_layout()

    print('Generating correlation matrix...')
    fig2, ax2 = plt.subplots(figsize=(16, 10))
    colormap = sb.blend_palette(["#00008B", "#6A5ACD", "#F0F8FF", "#FFE6F8", "#C71585", "#8B0000"], as_cmap=True)
    # NOTE(review): sb.corrplot only exists in old seaborn releases (it was
    # deprecated in 0.6); newer code would use sb.heatmap(training_data.corr()).
    sb.corrplot(training_data, annot=False, sig_stars=False, diag_names=False, cmap=colormap, ax=ax2)
    fig2.tight_layout()

    if pca is not None:
        print('Generating principal component plots...')
        X = pca.transform(X)
        class_count = np.count_nonzero(np.unique(y))
        titles = [
            'First & Second Principal Components',
            'Second & Third Principal Components',
            'Third & Fourth Principal Components',
        ]
        # Component pairs (0, 1), (1, 2), (2, 3), each in its own figure.
        for first, title in enumerate(titles):
            _plot_component_pair(X, y, class_count, first, title)
Example #58
0
	def _init_palette(self):
		"""Set ``self.palette`` to ``self.number_colors`` CSS ``"rgb(r, g, b)"`` strings.

		The colours are a seaborn blend from seagreen through ghostwhite to #4168B7.
		"""
		basis = sns.blend_palette(["seagreen", "ghostwhite", "#4168B7"], self.number_colors)
		# Old seaborn releases returned an (n, 4) RGBA ndarray, which the original
		# r, g, b, a unpacking assumed. Newer releases return a list of RGB tuples,
		# where `basis * 255` would repeat the list instead of scaling it.
		# Normalising to an array and keeping the RGB channels handles both.
		rgb = np.round(np.asarray(basis, dtype=float)[:, :3] * 255).astype(int)
		self.palette = ["rgb(%d, %d, %d)" % (r, g, b) for r, g, b in rgb]
Example #59
0
File: lindis.py Project: ctw/myhddm
def plot_correl(ax=None, figname="correlation_plot"):
	"""Plot distance-to-category-boundary against response time with robust fits.

	Reads ``FaceEigen_RT_keep.csv``, ``HouseEigen_RT_keep.csv`` and
	``StimEigen_RT_keep.csv`` from the current working directory. For each file,
	the second column is regressed on the first: all stimuli in black, faces in
	blue, houses in red. The figure is saved as ``figname + ".png"``.

	``ax`` is accepted for backward compatibility but is not used; a new figure
	and axes are always created.

	Returns the ``(fig, ax)`` pair.
	"""
	# One call both selects the "white" style and applies the legend overrides.
	sns.set_style("white", {"legend.scatterpoints": 1, "legend.frameon": True})

	dataf = pd.read_csv("FaceEigen_RT_keep.csv")
	datah = pd.read_csv("HouseEigen_RT_keep.csv")
	data_all = pd.read_csv("StimEigen_RT_keep.csv")

	fig = plt.figure(figsize=(5, 6))
	ax = fig.add_subplot(111)

	# Positional `.iloc` replaces `.ix` (removed in pandas 1.0). x/y are passed
	# by keyword because newer seaborn makes them keyword-only. Each dataset is
	# drawn once, so stacked copies no longer darken the markers or refit.
	sns.regplot(x=data_all.iloc[:, 0], y=data_all.iloc[:, 1], color='Black',
		fit_reg=True, robust=True, label='All, r=.326**', scatter=True,
		ci=None, scatter_kws={'s': 2}, ax=ax)
	sns.regplot(x=dataf.iloc[:, 0], y=dataf.iloc[:, 1], color='Blue',
		fit_reg=True, robust=True, label='Face, r=.320*', scatter=True,
		ci=None, scatter_kws={'s': 35}, ax=ax)
	sns.regplot(x=datah.iloc[:, 0], y=datah.iloc[:, 1], color='Red',
		fit_reg=True, robust=True, label='House, r=.333*', scatter=True,
		ci=None, scatter_kws={'s': 35}, ax=ax)

	fig.set_tight_layout(True)
	fig.subplots_adjust(left=.22, bottom=.14, top=.95, right=.7)
	ax.set_ylim([-1, 1])
	ax.set_xlim([2, 14])
	ax.set_xticklabels(np.arange(2, 16, 2), fontsize=10)
	# NOTE(review): labelpad, legend alpha and dpi were truncated in the source
	# and restored from the matching calls in plot_correl_bycue — confirm.
	ax.set_xlabel("Distance to Category Boundary", fontsize=12, labelpad=5)
	leg = ax.legend(loc='best', fancybox=True, fontsize=10)
	leg.get_frame().set_alpha(0.95)
	ax.set_ylabel("Response Time (s)", fontsize=12, labelpad=5)
	ax.set_yticklabels(np.arange(-1, 1.5, 0.5), fontsize=10)
	sns.despine()
	plt.savefig(figname + ".png", format='png', dpi=600)
	return fig, ax

def plot_correl_bycue(ax=None, figname="correlbycue_plot",
		csv_path="/Users/kyle/Desktop/beh_hddm/MDS_Analysis/dist_RTxCue_allcor.csv"):
	"""Plot distance-to-boundary vs RT per cue type, faces left and houses right.

	Reads ``csv_path``, which needs columns ``stim`` ('face'/'house'),
	``distance``, ``hcRT``, ``ncRT`` and ``fcRT``. It draws one regression per
	cue on each of two side-by-side axes and saves ``figname + ".png"`` at
	600 dpi.

	``ax`` is accepted for backward compatibility but is not used.

	Returns the created figure.
	"""
	# One call both selects the "white" style and applies the legend overrides.
	sns.set_style("white", {"legend.scatterpoints": 1, "legend.frameon": True})

	df = pd.read_csv(csv_path)

	dataf = df[df['stim'] == 'face']
	datah = df[df['stim'] == 'house']

	fig = plt.figure(figsize=(10, 12))
	axf = fig.add_subplot(121)
	axh = fig.add_subplot(122)

	# (data, target axes, RT column, colour, robust fit?, legend label), in draw order.
	# NOTE(review): the face/neutral-cue fit is the only non-robust one — confirm intended.
	fits = [
		(dataf, axf, 'hcRT', 'Red', True, 'House Cue, r=-.19'),
		(dataf, axf, 'ncRT', 'Black', False, 'Neutral Cue, r=-.15'),
		(dataf, axf, 'fcRT', 'Blue', True, 'Face Cue, r=-.320*'),
		(datah, axh, 'hcRT', 'Red', True, 'House Cue, r=-.330*'),
		(datah, axh, 'ncRT', 'Black', True, 'Neutral Cue, r=-.18'),
		(datah, axh, 'fcRT', 'Blue', True, 'face Cue, r=-.09'),
	]
	for data, target, column, color, robust, label in fits:
		# x/y by keyword: newer seaborn makes them keyword-only.
		sns.regplot(x=data['distance'], y=data[column], color=color, fit_reg=True,
			robust=robust, label=label, scatter=True, ci=None,
			scatter_kws={'s': 35}, ax=target)

	for axis in fig.axes:
		axis.set_ylim([-1.2, 1.2])
		axis.set_xlim([-5, 18])
		axis.set_xlabel("Distance to Category Boundary", fontsize=12, labelpad=5)
		leg = axis.legend(loc='best', fancybox=True, fontsize=10)
		leg.get_frame().set_alpha(0.95)
		axis.set_ylabel("Response Time (s)", fontsize=12, labelpad=5)
	# With no arguments despine() applies to every axes of the current figure,
	# so a single call after the loop is enough.
	sns.despine()

	plt.savefig(figname + ".png", format='png', dpi=600)

	return fig


def plot_rho_heatmap(csv_path="/Users/kyle/Desktop/beh_hddm/MDS_Analysis/dist_RTxCue_allcor.csv"):
	"""Heatmap of Spearman rho between boundary distance and RT, stimulus x cue.

	Rows are the Face and House stimuli; columns are the House, Neutral and Face
	cues (``hcRT``, ``ncRT``, ``fcRT`` in ``csv_path``). Each cell is annotated
	with r rounded to two decimals. Cells get "**" when r <= -.35 and "*" when
	-.35 < r <= -.30. The figure is saved as ``corr.png`` at 600 dpi.

	Returns the created figure.
	"""
	sns.set_style("white")
	pal = sns.blend_palette(['Darkred', 'Pink'], as_cmap=True)

	df = pd.read_csv(csv_path)

	# Row order (face, house) and column order (house, neutral, face cue) must
	# match the tick labels set below.
	rows = []
	for stim in ('face', 'house'):
		subset = df[df['stim'] == stim]
		rows.append([subset['distance'].corr(subset[col], method='spearman')
			for col in ('hcRT', 'ncRT', 'fcRT')])
	corr_matrix = np.array(rows)

	fig = plt.figure(figsize=(10, 8))
	# NOTE(review): with tight layout enabled, matplotlib recomputes subplot
	# parameters at draw time, so the subplots_adjust call below probably has no
	# lasting effect — confirm which layout is intended.
	fig.set_tight_layout(True)

	ax = fig.add_subplot(111)

	fig.subplots_adjust(top=.95, hspace=.1, left=0.10, right=.9, bottom=0.1)

	ax.set_ylim(-0.5, 1.5)
	ax.set_yticks([0, 1])
	ax.set_yticklabels(['Face', 'House'], fontsize=24)
	plt.setp(ax.get_yticklabels(), rotation=90)
	ax.set_ylabel("Stimulus", fontsize=28, labelpad=8)
	ax.set_xlim(-0.5, 2.5)
	ax.set_xticks([0, 1, 2])
	ax.set_xticklabels(['House', 'Neutral', 'Face'], fontsize=24)
	ax.set_xlabel("Cue Type", fontsize=28, labelpad=8)
	ax_map = ax.imshow(corr_matrix, interpolation='nearest', cmap=pal, origin='lower', vmin=-0.40, vmax=0)
	plt.colorbar(ax_map, ax=ax, shrink=0.65)

	for i, cond in enumerate(corr_matrix):
		for x, xval in enumerate(cond):
			# Round to two decimals. Slicing str(xval)[:5] truncated rather than
			# rounded, and kept three decimals for positive r but two for negative r.
			text = "r=%.2f" % xval
			# `<=` closes the old gap where r == -.35 matched neither star band.
			if xval <= -.35:
				ax.text(x, i, text + "**", ha='center', va='center', fontsize=29)
			elif xval <= -.30:
				ax.text(x, i, text + "*", ha='center', va='center', fontsize=29)
			else:
				ax.text(x, i, text, ha='center', va='center', fontsize=22)

	plt.savefig('corr.png', format='png', dpi=600)
	return fig
			branches = mytree.branchnames
			sns.corrplot(arr,names=branches, ax=CorrelationAxes, annot=False, diag_names=False)
			CorrelationAxes.tick_params(axis='both', labelsize=4)

			CorrelationAxes.set_xticklabels( [branchname.replace("_"," ") for branchname in branches] )
			CorrelationAxes.set_yticklabels( [branchname.replace("_"," ") for branchname in branches] )

			CorrelationAxes.annotate("Correlation Matrix: "+label, xy=(0.5, 0.95), xycoords='axes fraction', fontweight='bold', fontsize=10)
			CorrelationAxes.annotate(r"\textbf{\textit{ATLAS}} Internal", xy=(0.6, 0.05), xycoords='axes fraction', fontweight='bold', fontsize=10)

			plt.savefig("plots/var_correlation_%s.pdf"%(label)  )
			# code.interact(local=locals())



		tmpcmap = sns.blend_palette([ (1,1,1,0), colorpal[iSample] + (1,) ], 10)
		tmpcmap =  mpl.colors.ListedColormap(tmpcmap)

		OnePlotFig = plt.figure(figsize=(6, 6), dpi=100, facecolor='white')
		OnePlotAxes  = OnePlotFig.add_axes(rect_scatter)
		OnePlotAxesX = OnePlotFig.add_axes(rect_histx)
		OnePlotAxesY = OnePlotFig.add_axes(rect_histy)

		OnePlotAxesX.yaxis.grid()
		OnePlotAxesY.xaxis.grid()
		OnePlotAxesY.set_xticklabels([])
		OnePlotAxesY.set_yticklabels([])
		OnePlotAxesX.set_xticklabels([])
		OnePlotAxesX.set_yticklabels([])
		OnePlotAxesX.spines['top'].set_visible(False)
		OnePlotAxesY.spines['right'].set_visible(False)