예제 #1
0
def make_color_bar_images(directory=IMG_DIR, width=1.0, height=0.2):
    """
    Create color bar images for use in selection drop-downs

    Arguments
    ---------

    directory: string
              directory to store images

    width: float
           inches for the width of the color ramp image

    height: float
           inches for the height of the color ramp image


    Notes
    -----
    Creates one image for each color map in the colorbrewer schemes from palettable.
    Each image will be named `cmap_k.png` where cmap is the name of cmap from palettable, and k is the number of classes
    """
    import os

    for ctype_key in ['Diverging', 'Sequential', 'Qualitative']:
        ctype = colorbrewer.COLOR_MAPS[ctype_key]
        for cmap_key, sizes in ctype.items():
            # The keys of ``sizes`` are the class counts available for this
            # map (they are passed to get_map as ``number``); only the keys
            # are needed here.
            for k in sizes:
                # Use a dedicated name for the palette object so none of the
                # loop variables being iterated is rebound.
                brewer_map = colorbrewer.get_map(cmap_key, ctype_key, int(k))
                fname = os.path.join(directory,
                                     "{cmap_key}_{k}.png".format(cmap_key=cmap_key,
                                                                 k=k))
                brewer_map.save_discrete_image(filename=fname,
                                               size=(width, height))
예제 #2
0
def get_hex_colors(cmap, ctype, k):
    """return list of hex colors for cmap

    Arguments
    ---------

    cmap: string
          Blues, PrGn,......RdBu

    ctype: string
           Sequential, Diverging, Qualitative

    k: int
       number of classes/colors

    Returns
    -------
    list hex codes for k colors, or None (after printing a message) when
    colorbrewer.get_map rejects the cmap/ctype/k combination

    Example
    -------
    >>> get_hex_colors('Blues', 'sequential', 5)
    ['#EFF3FF', '#BDD7E7', '#6BAED6', '#3182BD', '#08519C']

    >>> get_hex_colors('sequential', 'Blues', 5)
    Cmap not defined: sequential Blues 5

    """
    try:
        return colorbrewer.get_map(cmap, ctype, k).hex_colors
    except ValueError:
        # get_map reports an unknown type/name/number with ValueError.
        # Anything else (e.g. KeyboardInterrupt or a programming error)
        # should propagate rather than be silently swallowed.
        print('Cmap not defined:', cmap, ctype, k)
        return None
예제 #3
0
파일: strings.py 프로젝트: escherba/lsh-hdc
def create_plots(args, df, metrics):
    """Save one scatter figure per metric, with one series per group.

    For every entry of ``metrics`` that is present in ``df``, the rows are
    grouped by ``args.group_by`` and each group is drawn against
    ``args.x_axis`` with a logarithmic x axis. Each figure is saved to
    ``args.output`` as ``fig_<metric>.<args.fig_format>`` and then closed.
    A failure while plotting one group is logged and that group is skipped.
    """
    import matplotlib.pyplot as plt
    from palettable import colorbrewer
    from matplotlib.font_manager import FontProperties

    legend_font = FontProperties()
    legend_font.set_size('small')

    grouped = df.groupby([args.group_by])
    # Requested palette size is clamped to the range 3..9
    # (presumably the sizes available for Set1).
    n_colors = min(max(len(grouped), 3), 9)

    for metric in metrics:
        if metric not in df:
            continue

        palette = colorbrewer.get_map('Set1', 'qualitative', n_colors).mpl_colors
        color_iter = cycle(palette)
        fig, ax = plt.subplots()

        for edge_color, (group_label, group_df) in izip(color_iter, grouped):
            try:
                group_df.plot(
                    ax=ax, label=group_label, x=args.x_axis, linewidth='1.3',
                    y=metric, kind="scatter", logx=True, title=args.fig_title,
                    facecolors='none', edgecolors=edge_color)
            except Exception:
                logging.exception("Exception caught plotting %s:%s", metric, group_label)

        out_path = os.path.join(args.output, "fig_%s.%s" % (metric, args.fig_format))
        legend_kwargs = LEGEND_METRIC_KWARGS.get(metric, {'loc': 'lower right'})
        ax.legend(prop=legend_font, **legend_kwargs)
        fig.savefig(out_path)
        plt.close(fig)
예제 #4
0
    def compare(self, others, scores, dtype=np.float16, plot=False):
        """Compute per-score AUCs between this object and each of ``others``.

        ``self.compute`` and every ``other.compute`` are called with
        ``(scores, dtype=dtype)``. Their results are indexed by score name,
        and for each name the two value collections are compared with
        ``dist_auc``.

        :param others: one object with a ``compute`` method, or an iterable
            of such objects
        :param scores: forwarded to every ``compute`` call
        :param dtype: forwarded to every ``compute`` call
        :param plot: when true, also show an overlaid histogram (NaNs
            dropped, shared bin edges) for every score name
        :returns: list with one ``{score_name: auc}`` dict per compared object
        """
        result0 = self.compute(scores, dtype=dtype)

        if not isiterable(others):
            others = [others]

        if plot:
            # The imports and the palette lookup do not depend on ``other``,
            # so do them once instead of once per compared object.
            from matplotlib import pyplot as plt
            from palettable import colorbrewer

            colors = colorbrewer.get_map("Set1", "qualitative", 9).mpl_colors

        result_grid = []
        for other in others:
            result1 = other.compute(scores, dtype=dtype)

            result_row = {}
            # items() exists on Python 2 and 3 mappings; iteritems() is
            # Python 2 only.
            for score_name, scores0 in result0.items():
                scores1 = result1[score_name]
                auc_score = dist_auc(scores0, scores1)
                result_row[score_name] = auc_score
                if plot:
                    scores0p = [x for x in scores0 if not np.isnan(x)]
                    scores1p = [x for x in scores1 if not np.isnan(x)]
                    hmin0, hmax0 = minmaxr(scores0p)
                    hmin1, hmax1 = minmaxr(scores1p)
                    # Shared bin edges (50 edges) keep the two histograms comparable.
                    bins = np.linspace(min(hmin0, hmin1), max(hmax0, hmax1), 50)
                    plt.hist(scores0p, bins, alpha=0.5, label="0", color=colors[0], edgecolor="none")
                    plt.hist(scores1p, bins, alpha=0.5, label="1", color=colors[1], edgecolor="none")
                    plt.legend(loc="upper right")
                    plt.title("%s: AUC=%.4f" % (score_name, auc_score))
                    plt.show()
            result_grid.append(result_row)
        return result_grid
예제 #5
0
def plot_proportions(alphas):
    """
    Given alpha (T, K), plot the overall popularity of each topic over time from left to right

    Each time step t becomes a column of K stacked rectangles. Topic k's
    rectangle spans from ``cumsum(alphas[t])[k] - alphas[t][k]`` to
    ``cumsum(alphas[t])[k]``, scaled by the plot height. For t > 0, an arrow
    runs from the middle of topic k's rectangle in column t-1 towards the
    middle of its rectangle in column t.

    Returns the ``pyplot`` module; the newly created figure is the current one.
    """
    T, K = alphas.shape
    # ColorBrewer palettes have a minimum size of 3 (and Set3 tops out at
    # 12), so clamp the request into that range. This avoids a ValueError
    # for K < 3. The palette is then cycled and truncated to exactly K colors.
    n_palette = max(3, min(K, 12))
    palette = colorbrewer.get_map('Set3', map_type='qualitative', number=n_palette).mpl_colors
    colors = list(itr.islice(itr.cycle(palette), K))
    W = .5    # width of one time-step column
    H = 5     # plot height; proportions are scaled by this
    PW = .2   # horizontal gap between columns (where the arrows are drawn)
    TOTAL_W = T * (W + PW)
    plt.figure(figsize=(TOTAL_W, H))
    ax = plt.gca()
    plt.xlim(0, TOTAL_W)
    plt.ylim(0, H)
    plt.axis('off')

    # top_by_time[t][k] is the upper edge (before scaling by H) of topic k at step t
    top_by_time = [np.cumsum(alphas[t]) for t in range(T)]
    for t in range(T):
        x = t * (W + PW)
        for k in range(K):
            prop = alphas[t][k]
            y = top_by_time[t][k] - prop
            color = colors[k]
            r = mpatches.Rectangle(xy=(x, y*H), width=W, height=prop*H, edgecolor='k', facecolor=color)
            ax.add_patch(r)

            if t != 0:
                # midpoint of the same topic's rectangle in the previous column
                left_y = top_by_time[t-1][k] - alphas[t-1][k]/2
                right_y = y + prop/2
                ax.arrow(x - PW, left_y*H, dx=PW*3/4, dy=(right_y - left_y)*H, head_length=PW/4, head_width=.1)

    return plt
예제 #6
0
 def plot_matplotlib(self, theme="Paired", scale_grid=False):
     """Use matplotlib to plot a phenotype phase plane in 3D.
     theme: color theme to use (requires palettable)
     scale_grid: if true, thin the wireframe to roughly 20 wires per
         direction instead of one wire per grid point
     returns: matplotlib 3d subplot object"""
     if pyplot is None:
         raise ImportError("Error importing matplotlib 3D plotting")
     # one color string of up to 7 characters (e.g. '#A6CEE3') per entry
     # of growth_rates
     colors = empty(self.growth_rates.shape, dtype=dtype((str, 7)))
     n_segments = self.segments.max()
     # pick colors: hard-coded 12-color fallback, replaced by the palettable
     # theme when palettable is available and supports this many phases
     color_list = [
         '#A6CEE3', '#1F78B4', '#B2DF8A', '#33A02C', '#FB9A99', '#E31A1C',
         '#FDBF6F', '#FF7F00', '#CAB2D6', '#6A3D9A', '#FFFF99', '#B15928'
     ]
     if get_map is not None:
         try:
             color_list = get_map(theme, 'Qualitative',
                                  n_segments).hex_colors
         except ValueError:
             from warnings import warn
             warn('palettable could not be used for this number of phases')
     if n_segments > len(color_list):
         from warnings import warn
         warn("not enough colors to color all detected phases")
     if n_segments > 0 and n_segments <= len(color_list):
         # segment labels are 1-based: label i + 1 gets color_list[i]
         for i in range(n_segments):
             colors[self.segments == (i + 1)] = color_list[i]
     else:
         colors[:, :] = 'b'
     if scale_grid:
         # grid wires should not have more than ~20 points. With fewer than
         # 20 points int(n / 20) is 0, which is not a usable stride
         # (plot_wireframe raises when both strides are 0), so clamp to 1.
         xgrid_scale = max(1, int(self.reaction1_npoints / 20))
         ygrid_scale = max(1, int(self.reaction2_npoints / 20))
     else:
         xgrid_scale, ygrid_scale = (1, 1)
     figure = pyplot.figure()
     xgrid, ygrid = meshgrid(self.reaction1_fluxes, self.reaction2_fluxes)
     axes = figure.add_subplot(111, projection="3d")
     xgrid = xgrid.transpose()
     ygrid = ygrid.transpose()
     axes.plot_surface(xgrid,
                       ygrid,
                       self.growth_rates,
                       rstride=1,
                       cstride=1,
                       facecolors=colors,
                       linewidth=0,
                       antialiased=False)
     axes.plot_wireframe(xgrid,
                         ygrid,
                         self.growth_rates,
                         color="black",
                         rstride=xgrid_scale,
                         cstride=ygrid_scale)
     axes.set_xlabel(self.reaction1_name, size="x-large")
     axes.set_ylabel(self.reaction2_name, size="x-large")
     axes.set_zlabel("Growth rate", size="x-large")
     axes.view_init(elev=30, azim=-135)
     figure.set_tight_layout(True)
     return axes
예제 #7
0
    def plot_matplotlib(self, theme="Paired", scale_grid=False):
        """Use matplotlib to plot a phenotype phase plane in 3D.

        theme: color theme to use (requires palettable)

        scale_grid: if true, thin the wireframe to roughly 20 wires per
        direction instead of one wire per grid point

        returns: matplotlib 3d subplot object"""
        if pyplot is None:
            raise ImportError("Error importing matplotlib 3D plotting")
        # Native-str dtype of length 7 (unicode on Python 3). A bytes ("|S7")
        # array would hand matplotlib bytes objects as colors, which its
        # color parsing does not accept as color strings on Python 3.
        colors = empty(self.growth_rates.shape, dtype=(str, 7))
        n_segments = self.segments.max()
        # pick colors: hard-coded 12-color fallback, replaced by the
        # palettable theme when palettable is available and supports this
        # number of phases
        color_list = [
            "#A6CEE3",
            "#1F78B4",
            "#B2DF8A",
            "#33A02C",
            "#FB9A99",
            "#E31A1C",
            "#FDBF6F",
            "#FF7F00",
            "#CAB2D6",
            "#6A3D9A",
            "#FFFF99",
            "#B15928",
        ]
        if get_map is not None:
            try:
                color_list = get_map(theme, "Qualitative", n_segments).hex_colors
            except ValueError:
                # get_map rejects phase counts the theme does not provide;
                # keep the fallback list instead of crashing.
                from warnings import warn

                warn("palettable could not be used for this number of phases")
        if n_segments > len(color_list):
            from warnings import warn

            warn("not enough colors to color all detected phases")
        if n_segments > 0 and n_segments <= len(color_list):
            # segment labels are 1-based: label i + 1 gets color_list[i]
            for i in range(n_segments):
                colors[self.segments == (i + 1)] = color_list[i]
        else:
            colors[:, :] = "b"
        if scale_grid:
            # grid wires should not have more than ~20 points. With fewer
            # than 20 points int(n / 20) is 0, which is not a usable stride
            # (plot_wireframe raises when both strides are 0), so clamp to 1.
            xgrid_scale = max(1, int(self.reaction1_npoints / 20))
            ygrid_scale = max(1, int(self.reaction2_npoints / 20))
        else:
            xgrid_scale, ygrid_scale = (1, 1)
        figure = pyplot.figure()
        xgrid, ygrid = meshgrid(self.reaction1_fluxes, self.reaction2_fluxes)
        axes = figure.add_subplot(111, projection="3d")
        xgrid = xgrid.transpose()
        ygrid = ygrid.transpose()
        axes.plot_surface(
            xgrid, ygrid, self.growth_rates, rstride=1, cstride=1, facecolors=colors, linewidth=0, antialiased=False
        )
        axes.plot_wireframe(xgrid, ygrid, self.growth_rates, color="black", rstride=xgrid_scale, cstride=ygrid_scale)
        axes.set_xlabel(self.reaction1_name)
        axes.set_ylabel(self.reaction2_name)
        axes.set_zlabel("Growth rate")
        axes.view_init(elev=30, azim=-135)
        return axes
예제 #8
0
파일: palettes.py 프로젝트: has2k1/mizani
    def _brewer_pal(n):
        """Return ``n`` hex colors from the enclosing ColorBrewer palette.

        Uses ``palette_name``, ``type``, ``n_max``, ``direction`` and
        ``brewer_helper`` from the enclosing scope. At most ``n_max`` real
        colors are produced. If more are requested, a warning is issued and
        the extra positions are filled with ``None``. The final list is
        sliced with step ``direction``, so a negative direction reverses it,
        padding included.
        """
        # Only draw the maximum allowable colors from the palette
        # and fill any remaining spots with None
        _n = min(n, n_max)
        try:
            bmap = colorbrewer.get_map(palette_name, type, _n)
        except ValueError:
            # Some palettes have a minimum no. of colors
            # We get around that restriction.
            n_min = brewer_helper.min_num_colors(type, palette_name)
            bmap = colorbrewer.get_map(palette_name, type, n_min)

        hex_colors = bmap.hex_colors[:n]
        if n > n_max:
            # Adjacent literals are joined with no separator, so each piece
            # must carry its own spacing/punctuation.
            msg = ("Warning message: "
                   f"Brewer palette {palette_name} has a maximum "
                   f"of {n_max} colors. Returning the palette you "
                   "asked for with that many colors")
            warn(msg)
            hex_colors = hex_colors + [None] * (n - n_max)
        return hex_colors[::direction]
예제 #9
0
    def plot_matplotlib(self, theme="Paired", scale_grid=False):
        """Use matplotlib to plot a phenotype phase plane in 3D.

        theme: color theme to use (requires palettable)

        scale_grid: if true, thin the wireframe to roughly 20 wires per
        direction instead of one wire per grid point

        returns: matplotlib 3d subplot object"""
        if pyplot is None:
            raise ImportError("Error importing matplotlib 3D plotting")
        # one color string of up to 7 characters (e.g. '#A6CEE3') per entry
        # of growth_rates
        colors = empty(self.growth_rates.shape, dtype=dtype((str, 7)))
        n_segments = self.segments.max()
        # pick colors: hard-coded 12-color fallback, replaced by the
        # palettable theme when palettable is available and supports this
        # number of phases
        color_list = ['#A6CEE3', '#1F78B4', '#B2DF8A', '#33A02C',
                      '#FB9A99', '#E31A1C', '#FDBF6F', '#FF7F00',
                      '#CAB2D6', '#6A3D9A', '#FFFF99', '#B15928']
        if get_map is not None:
            try:
                color_list = get_map(theme, 'Qualitative',
                                     n_segments).hex_colors
            except ValueError:
                from warnings import warn
                warn('palettable could not be used for this number of phases')
        if n_segments > len(color_list):
            from warnings import warn
            warn("not enough colors to color all detected phases")
        if n_segments > 0 and n_segments <= len(color_list):
            # segment labels are 1-based: label i + 1 gets color_list[i]
            for i in range(n_segments):
                colors[self.segments == (i + 1)] = color_list[i]
        else:
            colors[:, :] = 'b'
        if scale_grid:
            # grid wires should not have more than ~20 points. With fewer
            # than 20 points int(n / 20) is 0, which is not a usable stride
            # (plot_wireframe raises when both strides are 0), so clamp to 1.
            xgrid_scale = max(1, int(self.reaction1_npoints / 20))
            ygrid_scale = max(1, int(self.reaction2_npoints / 20))
        else:
            xgrid_scale, ygrid_scale = (1, 1)
        figure = pyplot.figure()
        xgrid, ygrid = meshgrid(self.reaction1_fluxes, self.reaction2_fluxes)
        axes = figure.add_subplot(111, projection="3d")
        xgrid = xgrid.transpose()
        ygrid = ygrid.transpose()
        axes.plot_surface(xgrid, ygrid, self.growth_rates, rstride=1,
                          cstride=1, facecolors=colors, linewidth=0,
                          antialiased=False)
        axes.plot_wireframe(xgrid, ygrid, self.growth_rates, color="black",
                            rstride=xgrid_scale, cstride=ygrid_scale)
        axes.set_xlabel(self.reaction1_name, size="x-large")
        axes.set_ylabel(self.reaction2_name, size="x-large")
        axes.set_zlabel("Growth rate", size="x-large")
        axes.view_init(elev=30, azim=-135)
        figure.set_tight_layout(True)
        return axes
예제 #10
0
파일: palettes.py 프로젝트: has2k1/mizani
    def _brewer_pal(n):
        """Return ``n`` hex colors from the enclosing ColorBrewer palette.

        Uses ``palette_name``, ``type`` and ``nmax`` from the enclosing scope.
        At most ``nmax`` real colors are produced. If more are requested, a
        warning is issued and the extra positions are filled with ``None``.
        """
        # Only draw the maximum allowable colors from the palette
        # and fill any remaining spots with None
        _n = min(n, nmax)
        try:
            bmap = colorbrewer.get_map(palette_name, type, _n)
        except ValueError:
            # Some palettes have a minimum no. of colors set at 3
            # We get around that restriction.
            if 0 <= _n < 3:
                bmap = colorbrewer.get_map(palette_name, type, 3)
            else:
                # bare raise re-raises the active exception unchanged
                raise

        hex_colors = bmap.hex_colors[:n]
        if n > nmax:
            # Adjacent literals are joined with no separator before
            # .format() runs, so each piece carries its own spacing.
            msg = ("Warning message: "
                   "Brewer palette {} has a maximum of {} colors. "
                   "Returning the palette you asked for with "
                   "that many colors".format(palette_name, nmax))
            warnings.warn(msg)
            hex_colors = hex_colors + [None] * (n - nmax)
        return hex_colors
예제 #11
0
    def _brewer_pal(n):
        """Return ``n`` hex colors from the enclosing ColorBrewer palette.

        Uses ``palette_name``, ``type`` and ``nmax`` from the enclosing scope.
        At most ``nmax`` real colors are produced. If more are requested, a
        warning is issued and the extra positions are filled with ``None``.
        """
        # Only draw the maximum allowable colors from the palette
        # and fill any remaining spots with None
        _n = min(n, nmax)
        try:
            bmap = colorbrewer.get_map(palette_name, type, _n)
        except ValueError:
            # Some palettes have a minimum no. of colors set at 3
            # We get around that restriction.
            if 0 <= _n < 3:
                bmap = colorbrewer.get_map(palette_name, type, 3)
            else:
                # bare raise re-raises the active exception unchanged
                raise

        hex_colors = bmap.hex_colors[:n]
        if n > nmax:
            # Adjacent literals are joined with no separator before
            # .format() runs, so each piece carries its own spacing.
            msg = ("Warning message: "
                   "Brewer palette {} has a maximum of {} colors. "
                   "Returning the palette you asked for with "
                   "that many colors".format(palette_name, nmax))
            warnings.warn(msg)
            hex_colors = hex_colors + [None] * (n - nmax)
        return hex_colors
예제 #12
0
def brewer_colour(id_col, map_type='diverging', name='RdYlBu', number=11):
    """Return one colour of a ColorBrewer palette and a descriptive label.

    Args:
        id_col (int): 1-based position of the colour in the palette;
            must satisfy ``1 <= id_col <= number``.
        map_type (str='diverging'): palette type passed to ``get_map``.
        name (str='RdYlBu'): palette name passed to ``get_map``.
        number (int=11): number of colours requested from ``get_map``.

    Returns:
        hex (str), name (str): the colour's hex code, and a label of the form
        ``brewer(<map_type>.<name>_<number>[<id_col zero-padded>])``.

    Raises:
        IndexError: if ``id_col`` is outside ``1..number``.
    """
    if not 1 <= id_col <= number:
        # id_col is 1-based. Without this check, 0 or negative values would
        # silently select colours counted from the end of the palette.
        raise IndexError(
            f'id_col must be between 1 and {number}, got {id_col}')
    hex_code = get_map(name, map_type, number).hex_colors[id_col - 1]
    # zero-pad id_col to the number of digits in `number`
    str_format = f'0{len(str(number))}d'
    label = f'brewer({map_type}.{name}_{number}[{format(id_col, str_format)}])'

    return hex_code, label
예제 #13
0
    def compare(self, others, scores, dtype=np.float16, plot=False):
        """Compute per-score AUCs between this object and each of ``others``.

        ``self.compute`` and every ``other.compute`` are called with
        ``(scores, dtype=dtype)``. Their results are indexed by score name,
        and for each name the two value collections are compared with
        ``dist_auc``.

        :param others: one object with a ``compute`` method, or an iterable
            of such objects
        :param scores: forwarded to every ``compute`` call
        :param dtype: forwarded to every ``compute`` call
        :param plot: when true, also show an overlaid histogram (NaNs
            dropped, shared bin edges) for every score name
        :returns: list with one ``{score_name: auc}`` dict per compared object
        """
        result0 = self.compute(scores, dtype=dtype)

        if not isiterable(others):
            others = [others]

        if plot:
            # The imports and the palette lookup do not depend on ``other``,
            # so do them once instead of once per compared object.
            from matplotlib import pyplot as plt
            from palettable import colorbrewer
            colors = colorbrewer.get_map('Set1', 'qualitative',
                                         9).mpl_colors

        result_grid = []
        for other in others:
            result1 = other.compute(scores, dtype=dtype)

            result_row = {}
            # items() exists on Python 2 and 3 mappings; iteritems() is
            # Python 2 only.
            for score_name, scores0 in result0.items():
                scores1 = result1[score_name]
                auc_score = dist_auc(scores0, scores1)
                result_row[score_name] = auc_score
                if plot:
                    scores0p = [x for x in scores0 if not np.isnan(x)]
                    scores1p = [x for x in scores1 if not np.isnan(x)]
                    hmin0, hmax0 = minmaxr(scores0p)
                    hmin1, hmax1 = minmaxr(scores1p)
                    # Shared bin edges (50 edges) keep the two histograms comparable.
                    bins = np.linspace(min(hmin0, hmin1), max(hmax0, hmax1),
                                       50)
                    plt.hist(scores0p,
                             bins,
                             alpha=0.5,
                             label='0',
                             color=colors[0],
                             edgecolor="none")
                    plt.hist(scores1p,
                             bins,
                             alpha=0.5,
                             label='1',
                             color=colors[1],
                             edgecolor="none")
                    plt.legend(loc='upper right')
                    plt.title("%s: AUC=%.4f" % (score_name, auc_score))
                    plt.show()
            result_grid.append(result_row)
        return result_grid
예제 #14
0
def waveplot(lat, lon, ts, te, Var, overlay):
    """Draw one time step of ``Var`` on the module-level map ``m`` and save it.

    lat, lon: coordinate arrays. They are passed to the map projection ``m``,
        and their min/max define the extent of the overlay figure.
    ts: index of the time step drawn (``Var[ts, :, :]``).
    te: unused; kept so existing callers keep working.
    Var: array indexed as ``Var[ts, :, :]`` -- presumably (time, y, x);
        TODO confirm.
    overlay: output path; the figure is saved there as PNG.

    Returns ``overlay``.

    Relies on the module-level names ``m``, ``gearth_fig``, ``pixels`` and
    ``colorbrewer``.
    """
    Vars = Var[ts, :, :]
    # reversed 11-class RdYlGn diverging ramp, as a continuous colormap
    cmap = colorbrewer.get_map('RdYlGn', 'diverging', 11,
                               reverse=True).mpl_colormap
    # Take the extent from the coordinates actually passed in, not from
    # module-level globals.
    fig, _ax = gearth_fig(llcrnrlon=lon.min(),
                          llcrnrlat=lat.min(),
                          urcrnrlon=lon.max(),
                          urcrnrlat=lat.max(),
                          pixels=pixels)

    (x, y) = m(lon, lat)
    m.pcolormesh(x, y, Vars, cmap=cmap)
    fig.savefig(overlay, transparent=False, format='png')

    return overlay
예제 #15
0
# https://nbviewer.jupyter.org/github/ocefpaf/PIRATA/blob/master/CTD-PIRATA-Processing.ipynb

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
from palettable import colorbrewer

# Colour ramp for bathymetry/topography: the nine "Blues" in reverse order
# (ocean) followed by the nine "Greens" (land), as one array of
# matplotlib-style colours.
OCEAN = colorbrewer.get_map("Blues", "sequential", 9, reverse=True)
LAND = colorbrewer.get_map("Greens", "sequential", 9)
LAND_OCEAN = np.array(OCEAN.mpl_colors + LAND.mpl_colors)


def make_map(extent, figsize=(12, 12), projection=ccrs.PlateCarree()):
    """Create a figure with a cartopy GeoAxes limited to ``extent``.

    extent: passed to ``ax.set_extent`` -- presumably
        ``[lon_min, lon_max, lat_min, lat_max]``; TODO confirm.
    figsize: matplotlib figure size.
    projection: cartopy CRS for the axes. The default instance is created
        once, when the function is defined, and shared between calls.

    Gridline labels are kept only on the bottom and left edges, the
    gridlines themselves are hidden, and 50m coastlines are drawn.

    Returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=figsize,
                           subplot_kw={"projection": projection})
    ax.set_extent(extent)
    gl = ax.gridlines(draw_labels=True)
    # Newer cartopy renamed the per-edge label switches (xlabels_top ->
    # top_labels, ylabels_right -> right_labels) and deprecated the old
    # names, so setting the old names may not hide the labels there. Use
    # whichever API this cartopy version provides.
    if hasattr(gl, "top_labels"):
        gl.top_labels = gl.right_labels = False
    else:
        gl.xlabels_top = gl.ylabels_right = False
    gl.ylines = gl.xlines = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    ax.coastlines(resolution="50m")
    return fig, ax


def add_etopo2(extent, ax, levels=None):
    import iris

    url = (
        "http://gamone.whoi.edu/thredds/dodsC/usgs/data0/bathy/ETOPO2v2c_f4.nc"
예제 #16
0
# Report how many overlay frames the loop below will try to render: one
# for every 24th time index from the current ``i`` up to len(tim.data) - 1.
# (``len((tim.data) - 1)`` subtracted 1 from the data, not from its length.)
print('Images(' + str(len(range(i, len(tim.data) - 1, 24))) + ")")

while i < len(tim.data) - 1:
    try:
        image = waveplot(lats, lons, i, i + 1, Var, ourfile + str(i) + '.png')
        i = i + 24  # advance 24 time indices per rendered frame
        images.append(image)
    except Exception as exc:
        # Stop at the first frame that fails, but say why instead of
        # silently swallowing every error (a bare except also caught Ctrl-C).
        print('Stopping image generation at time index %s: %s' % (i, exc))
        break

# Draw the first time step so the colorbar below has a mappable to describe.
Vars = Var[0, :, :]
(x, y) = m(lons, lats)
cmap = colorbrewer.get_map('RdYlGn', 'diverging', 11,
                           reverse=True).mpl_colormap
cs = m.pcolormesh(x, y, Vars, cmap=cmap)

# Stand-alone colorbar image, saved as the legend.
fig = plt.figure(figsize=(2.0, 7.0), facecolor=None, frameon=True)
ax = fig.add_axes([0.0, 0.05, 0.2, 0.9])
cb = fig.colorbar(cs, cax=ax)
cb.set_label('Significant wave height (WAM8km model at 00Z) from MET Norway',
             rotation=-90,
             color='k',
             labelpad=20)
# Transparent background so only the colorbar itself shows in the legend image.
fig.savefig('legend.png', transparent=True, format='png')

kml = make_kml(llcrnrlon=lons.min(),
               llcrnrlat=lats.min(),
               urcrnrlon=lons.max(),