示例#1
0
def plot_colors(c, figsize=(10, 1), ticks=False, **kwargs):
    """
    Plot a horizontal colorbar to inspect colors

    Every color is drawn as one equally wide cell of the bar.

    Parameters
    ----------
    c : array-like or Colormap instance
        Iterable containing colors (each element, or each row of a 2-D
        :obj:`~numpy.ndarray`, is one color spec such as an RGB(A) tuple),
        or a :class:`~matplotlib.colors.Colormap` whose ``N`` entries are shown.
    figsize : tuple, optional
        Matplotlib figsize parameter
    ticks : bool, optional
        If True, label each cell with its index; otherwise hide the x ticks.
    **kwargs
        Extra keyword arguments forwarded to `ColorbarBase`.

    Raises
    ------
    TypeError
        If `c` is neither a list/tuple/ndarray of colors nor a Colormap.
    ValueError
        If `c` contains no colors.
    """
    # Validate the input before touching any global state.
    if isinstance(c, (list, tuple, np.ndarray)):
        c = np.array(c)
        N = c.shape[0]
        if N == 0:
            raise ValueError("`c` must contain at least one color")
        cmap = LinearSegmentedColormap.from_list('plot', c, N=N)
    elif isinstance(c, colors.Colormap):
        N = c.N
        cmap = c
    else:
        raise TypeError(
            "`c` must be a list/tuple/ndarray of colors or a Colormap, "
            "got {}".format(type(c).__name__))

    # NOTE: mutates the global rcParams as a side effect (kept for compatibility).
    plt.rcParams['savefig.pad_inches'] = 0

    fig, ax = plt.subplots(figsize=figsize)

    # One cell per color, spanning [0, 1].
    bounds = np.linspace(0, 1, N + 1)
    # Discrete BoundaryNorm for small palettes. For N >= 256 pass None so that
    # ColorbarBase falls back to its default linear Normalize (the norm must
    # always be defined before the call below).
    norm = colors.BoundaryNorm(bounds, N) if N < 256 else None

    cb = ColorbarBase(ax,
                      cmap=cmap,
                      norm=norm,
                      spacing='proportional',
                      ticks=None,
                      boundaries=bounds,
                      format='%1i',
                      orientation='horizontal',
                      **kwargs)
    ax.patch.set_edgecolor('black')

    if not ticks:
        # Hide x ticks and labels on the colorbar axes itself.
        ax.tick_params(
            axis='x',
            which='both',
            bottom=False,
            top=False,
            labelbottom=False)
    else:
        # One tick at the centre of each cell, labelled with the cell index.
        cb.set_ticks(np.linspace(1 / (2 * N), 1 - 1 / (2 * N), N))
        cb.set_ticklabels(np.arange(N))
示例#2
0
def spikesplot_cb(position, cmap='viridis', fig=None):
    """Add a horizontal 'Inferior / (axial slice) / Superior' colorbar axis.

    Parameters
    ----------
    position : sequence of 4 floats
        Rectangle handed to ``fig.add_axes``.
    cmap : str
        Colormap name, resolved with ``get_cmap``.
    fig : Figure, optional
        Figure to draw on; defaults to ``plt.gcf()``.

    Returns
    -------
    The newly created colorbar axes.
    """
    target_fig = plt.gcf() if fig is None else fig

    bar_axes = target_fig.add_axes(position)
    colorbar = ColorbarBase(
        bar_axes,
        cmap=get_cmap(cmap),
        spacing='proportional',
        orientation='horizontal',
        drawedges=False,
    )
    colorbar.set_ticks([0, 0.5, 1.0])
    colorbar.set_ticklabels(['Inferior', '(axial slice)', 'Superior'])
    # Hide the outline and the tick marks; only the labels remain visible.
    colorbar.outline.set_linewidth(0)
    colorbar.ax.xaxis.set_tick_params(width=0)
    return bar_axes
示例#3
0
def spikesplot_cb(position, cmap='viridis', fig=None):
    """Create a horizontal colorbar axes labelled Inferior -> Superior.

    ``position`` is the ``[left, bottom, width, height]`` rectangle for
    ``fig.add_axes``; ``cmap`` is a colormap name resolved through
    ``cm.get_cmap``; ``fig`` defaults to the current figure.
    Returns the colorbar axes.
    """
    canvas_fig = fig
    if canvas_fig is None:
        canvas_fig = plt.gcf()

    cbar_axes = canvas_fig.add_axes(position)
    bar = ColorbarBase(cbar_axes,
                       cmap=cm.get_cmap(cmap),
                       spacing='proportional',
                       orientation='horizontal',
                       drawedges=False)

    tick_positions = [0, 0.5, 1.0]
    tick_texts = ['Inferior', '(axial slice)', 'Superior']
    bar.set_ticks(tick_positions)
    bar.set_ticklabels(tick_texts)

    # Invisible frame and zero-width tick marks: labels only.
    bar.outline.set_linewidth(0)
    bar.ax.xaxis.set_tick_params(width=0)
    return cbar_axes
示例#4
0
文件: verify.py 项目: menchant/bio96
def pick_colors(ax, df, attr, cmap):
    """Draw a colorbar for ``df[attr]`` on ``ax`` and return the color mapping.

    The ``Colors`` helper (defined elsewhere) supplies the norm, colormap,
    boundaries, ticks and tick labels used for the bar.
    """
    from matplotlib.colorbar import ColorbarBase

    palette = Colors(cmap, df[attr])

    bar_kwargs = dict(
        norm=palette.norm,
        cmap=palette.cmap,
        boundaries=palette.boundaries,
    )
    bar = ColorbarBase(ax, **bar_kwargs)
    bar.set_ticks(palette.ticks)
    bar.set_ticklabels(palette.ticklabels)

    # Flip the y axis so the lowest value is drawn at the top
    # (assuming the default vertical colorbar orientation).
    ax.invert_yaxis()

    return palette
示例#5
0
def plot_colorbar(cax, cmap, cnorm, hue_norm, linewidth=0.5):
    """Draw a vertical colorbar, extended at both ends, into ``cax``.

    Parameters
    ----------
    cax : Axes
        Axes that receives the colorbar.
    cmap : str or Colormap
        A string is resolved with ``get_cmap``.
    cnorm : Normalize
        Norm mapping data values onto the colormap.
    hue_norm : sequence
        Ticks are placed at ``hue_norm[0]``, ``sum(hue_norm) / 2`` and
        ``hue_norm[1]``, labelled with one decimal place.
    linewidth : float
        Width of the outline and tick marks.

    Returns
    -------
    The ``cax`` axes.
    """
    resolved_cmap = get_cmap(cmap) if isinstance(cmap, str) else cmap

    colorbar = ColorbarBase(cax, cmap=resolved_cmap, norm=cnorm,
                            orientation='vertical', extend='both')

    tick_values = [hue_norm[0], sum(hue_norm) / 2, hue_norm[1]]
    # TODO automatic ticklabel format, auto sci-format, float trim etc
    tick_labels = [f'{value:.1f}' for value in tick_values]
    colorbar.set_ticks(tick_values)
    colorbar.set_ticklabels(tick_labels)

    colorbar.outline.set_linewidth(linewidth)
    colorbar.ax.tick_params(size=1, pad=1, width=linewidth, length=0.3)
    return cax
示例#6
0
    def display_median_price_animation(self):
        """Kicks off the animation of median price information.

        Builds a 10x12 figure with a vertical colorbar (``self._colormap``)
        on the right and a main axes on the left, then runs a FuncAnimation
        with one frame per year from ``self.startyear`` to ``self.endyear``
        inclusive and blocks in ``plotter.show()``.
        """
        fig = plotter.figure(num=1, figsize=(10, 12), tight_layout=True)
        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6;
        # the title lives on the canvas manager, which may be None for
        # non-GUI backends.
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title('Percent increase in median house '
                                     'price since 1996')

        axis = fig.add_axes([0.85, 0.04, 0.03, 0.92])
        colorbar_ticks = [0, .2, .4, .6, .8, 1.0]
        colorbar_labels = ['-100%', '0%', '250%', '500%', '750%', '>1000%']
        # cmap must be a keyword: in Matplotlib >= 3.6 the second positional
        # parameter of ColorbarBase (an alias of Colorbar) is `mappable`.
        colorbar = ColorbarBase(axis, cmap=self._colormap,
                                orientation='vertical')
        colorbar.set_ticks(colorbar_ticks)
        colorbar.set_ticklabels(colorbar_labels)

        fig.add_axes([0.0, 0.0, 0.82, 1.0])
        # Keep a live reference to the animation while show() blocks;
        # otherwise it may be garbage-collected and stop updating.
        anim = FuncAnimation(fig,
                             self._animate,
                             frames=self.endyear + 1 - self.startyear,
                             interval=1000,
                             blit=True,
                             init_func=self._init_animate,
                             repeat_delay=3000)
        plotter.show()
示例#7
0
    def display_median_price_animation(self):
        """Kicks off the animation of median price information.

        Creates figure 1 (10x12) with a vertical colorbar driven by
        ``self._colormap`` and a main axes, animates one frame per year in
        ``[self.startyear, self.endyear]`` and blocks in ``plotter.show()``.
        """
        fig = plotter.figure(num=1, figsize=(10, 12), tight_layout=True)
        # canvas.set_window_title was removed in Matplotlib 3.6; set the
        # title through the manager when one exists (GUI backends).
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title('Percent increase in median house '
                                     'price since 1996')

        axis = fig.add_axes([0.85, 0.04, 0.03, 0.92])
        colorbar_ticks = [0, .2, .4, .6, .8, 1.0]
        colorbar_labels = ['-100%', '0%', '250%', '500%', '750%', '>1000%']
        # Pass cmap by keyword: positionally it would be taken as the
        # `mappable` argument in Matplotlib >= 3.6.
        colorbar = ColorbarBase(axis, cmap=self._colormap,
                                orientation='vertical')
        colorbar.set_ticks(colorbar_ticks)
        colorbar.set_ticklabels(colorbar_labels)

        fig.add_axes([0.0, 0.0, 0.82, 1.0])
        # The reference must stay alive while show() blocks, or the
        # animation can be garbage-collected.
        anim = FuncAnimation(fig,
                             self._animate,
                             frames=self.endyear + 1 - self.startyear,
                             interval=1000,
                             blit=True,
                             init_func=self._init_animate,
                             repeat_delay=3000)
        plotter.show()
示例#8
0
def init():
    """Create a square main axes plus a discrete colorbar for ``VALS``.

    The colorbar (right-hand axes, 5% of the width) has one unit-wide cell
    per entry of the module-level ``VALS``, colored with the current default
    colormap after autoscaling a ``Normalize`` to ``VALS``; each cell is
    outlined in black and labelled with its value.

    Returns
    -------
    (fig, ax) : the figure and the main (square) axes.
    """
    fig, (ax, cbar_ax) = plt.subplots(ncols=2,
                                      gridspec_kw=dict(width_ratios=(0.95,
                                                                     0.05)))
    ax.set_aspect(1.)
    fig.subplots_adjust(right=0.9)

    norm = Normalize()
    norm.autoscale(VALS)
    n_vals = len(VALS)
    c_bar = ColorbarBase(
        cbar_ax,
        values=VALS,
        cmap=plt.get_cmap(),
        norm=norm,
        # One cell per value: len(boundaries) must equal len(values) + 1.
        boundaries=range(n_vals + 1),
        ticklocation='right',
        # Ticks at the centre of each unit-wide cell.
        ticks=[x + 0.5 for x in range(n_vals)],
    )
    c_bar.solids.set_edgecolor("k")
    c_bar.set_ticklabels(VALS)
    return fig, ax
示例#9
0
def _plot(func, result, *, ax=None, add_cbar=True, **kwargs):
    """Prepare a unit-square axes for ``result`` and render it with ``func``.

    Parameters
    ----------
    func : callable
        Called as ``func(result, ax=ax, **kwargs)``; must return
        ``(ax, cmap, norm, vals)``.
    result : object
        Must expose ``limits``; ``limits[0]`` / ``limits[1]`` label the
        x / y ends of the unit square.
    ax : Axes, optional
        Target axes; a new 4x4 figure is created when omitted.
    add_cbar : bool
        Whether to add a discrete colorbar with one cell per entry of ``vals``.

    Returns
    -------
    The figure containing ``ax``.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig = plt.figure(figsize=[4, 4])
        ax = fig.add_subplot(111)

    # Unit square with ticks only at its two ends.
    for setter in (ax.set_xlim, ax.set_ylim, ax.set_xticks, ax.set_yticks):
        setter([0, 1])
    ax.set_xticklabels(result.limits[0])
    ax.set_yticklabels(result.limits[1])

    ax, cmap, norm, vals = func(result, ax=ax, **kwargs)

    if not add_cbar:
        return fig

    fig.subplots_adjust(right=0.9)
    bar_axes = fig.add_axes([0.95, 0.1, 0.04, 0.8])

    normalized = [norm(value) for value in vals]
    discrete_cmap = ListedColormap([cmap(level) for level in normalized])
    # Midpoints of len(vals) equal-width bins between min(vals) and max(vals).
    tick_positions = np.linspace(min(vals), max(vals), 2 * len(vals) + 1)[1::2]
    c_bar = ColorbarBase(bar_axes, cmap=discrete_cmap, norm=norm,
                         ticks=tick_positions)
    c_bar.solids.set_edgecolor("k")
    c_bar.set_ticklabels(vals)

    return fig
示例#10
0
def Plot_Compare_PerpYear(num_clusters,
                          bmus_values_sim,
                          bmus_dates_sim,
                          bmus_values_hist,
                          bmus_dates_hist,
                          n_sim=1,
                          month_ini=1,
                          show=True):
    '''
    Plot simulated - historical bmus comparison in a perpetual year

    bmus_dates requires 1 day resolution time
    bmus_values set min value has to be 1 (not 0)

    Draws two stacked panels (historical on top, simulation below) via
    axplot_PerpYear, plus a shared colorbar with one cell per cluster,
    labelled 1..num_clusters. If either date series is not daily (judged
    from its first two entries) a message is printed and None is returned;
    otherwise the figure is returned (and shown first when ``show`` is True).
    '''
    # Daily resolution check on the first step of each series.
    step_hist = bmus_dates_hist[1] - bmus_dates_hist[0]
    step_sim = bmus_dates_sim[1] - bmus_dates_sim[0]
    if not (step_hist.days == 1 and step_sim.days == 1):
        print('PerpetualYear bmus comparison skipped.')
        print('timedelta (days): Hist - {0}, Sim - {1})'.format(
            step_hist.days, step_sim.days))
        return

    fig, (ax_hist, ax_sim) = plt.subplots(
        2, 1, figsize=(_faspect * _fsize, _fsize))

    # (axes, values, dates, number of simulations, title) per panel
    panels = (
        (ax_hist, bmus_values_hist, bmus_dates_hist, 1, 'Historical'),
        (ax_sim, bmus_values_sim, bmus_dates_sim, n_sim, 'Simulation'),
    )
    for panel_ax, values, dates, sims, title in panels:
        axplot_PerpYear(
            panel_ax,
            num_clusters,
            values,
            dates,
            num_sim=sims,
            month_ini=month_ini,
        )
        panel_ax.set_title(title)

    # Discrete colorbar: cell k spans [k, k+1] and is labelled k+1.
    cluster_rgb = GetClusterColors(num_clusters)
    cluster_cmap = mcolors.ListedColormap([tuple(row) for row in cluster_rgb])
    bar_axes = fig.add_axes([0.92, 0.125, 0.025, 0.755])
    cbar = ColorbarBase(
        bar_axes,
        cmap=cluster_cmap,
        norm=mcolors.Normalize(vmin=0, vmax=num_clusters),
        ticks=np.arange(num_clusters) + 0.5,
    )
    cbar.ax.tick_params(labelsize=8)
    cbar.set_ticklabels(range(1, num_clusters + 1))

    fig.suptitle('Perpetual Year', fontweight='bold', fontsize=12)

    if show:
        plt.show()
    return fig
示例#11
0
def plot_colorbar(lims,
                  ticks=None,
                  ticklabels=None,
                  figsize=(1, 2),
                  labelsize='small',
                  ticklabelsize='x-small',
                  ax=None,
                  label='',
                  tickrotation=0.,
                  orientation='vertical',
                  end_labels=None,
                  colormap='mne',
                  transparent=False,
                  diverging=None):
    """Plot a standalone colorbar for an MNE-style colormap.

    Parameters
    ----------
    lims : sequence of float
        Control-point values, passed to
        ``mne.viz._3d._limits_to_control_points`` as ``pos_lims`` when the
        colormap is treated as diverging and as ``lims`` otherwise.
    ticks : sequence of float | None
        Tick positions; if None, the ticks suggested by
        ``_limits_to_control_points`` are used.
    ticklabels : sequence | None
        Labels for ``ticks``; defaults to the tick values themselves.
    figsize : tuple
        Figure size used only when ``ax`` is None.
    labelsize, ticklabelsize : str | float
        Font sizes applied through a temporary ``rc_context``;
        ``ticklabelsize`` is also used for ``end_labels``.
    ax : Axes | None
        Axes to draw into; a new figure is created when None (and its
        subplot parameters are then adjusted).
    label : str
        Colorbar label.
    tickrotation : float
        Rotation of the major tick labels (degrees).
    orientation : 'vertical' | 'horizontal'
        Colorbar orientation.
    end_labels : sequence of str | None
        Two texts placed just beyond the ends of the bar.
    colormap : str
        Colormap name understood by ``_limits_to_control_points``.
    transparent : bool
        Forwarded to ``_limits_to_control_points``.
    diverging : bool | None
        If None, inferred as ``colormap == 'mne'``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the colorbar axes.

    Raises
    ------
    ValueError
        If ``ticks`` and ``ticklabels`` differ in length.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colorbar import ColorbarBase
    from matplotlib.colors import Normalize
    from mne.viz._3d import _limits_to_control_points
    with plt.rc_context({
            'axes.labelsize': labelsize,
            'xtick.labelsize': ticklabelsize,
            'ytick.labelsize': ticklabelsize
    }):
        if diverging is None:
            diverging = (colormap == 'mne')  # simple heuristic here
        if diverging:
            use_lims = dict(kind='value', pos_lims=lims)
        else:
            use_lims = dict(kind='value', lims=lims)
        cmap, scale_pts, diverging, _, none_ticks = _limits_to_control_points(
            use_lims, 0, colormap, transparent, linearize=True)
        vmin, vmax = scale_pts[0], scale_pts[-1]
        if ticks is None:
            ticks = none_ticks
        del colormap, lims, use_lims
        adjust = (ax is None)
        if ax is None:
            fig, ax = plt.subplots(1, figsize=figsize)
        else:
            fig = ax.figure
        norm = Normalize(vmin=vmin, vmax=vmax)
        if ticklabels is None:
            ticklabels = ticks
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(ticks) != len(ticklabels):
            raise ValueError(
                'ticks and ticklabels must have the same length, '
                f'got {len(ticks)} and {len(ticklabels)}')
        # cmap is passed by keyword: in Matplotlib >= 3.6 the second
        # positional parameter of ColorbarBase (alias of Colorbar) is
        # `mappable`, not `cmap`.
        cbar = ColorbarBase(ax,
                            cmap=cmap,
                            norm=norm,
                            ticks=ticks,
                            label=label,
                            orientation=orientation)
        # Keep only the right spine (vertical) or bottom spine (horizontal).
        for key in ('left', 'top',
                    'bottom' if orientation == 'vertical' else 'right'):
            ax.spines[key].set_visible(False)
        cbar.set_ticklabels(ticklabels)
        cbar.patch.set(facecolor='0.5', edgecolor='0.5')
        if orientation == 'horizontal':
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=tickrotation)
        else:
            plt.setp(ax.yaxis.get_majorticklabels(), rotation=tickrotation)
        cbar.outline.set_visible(False)
        # `lims` is reused here for the axes limits: [x0, x1, y0, y1].
        lims = np.array(list(ax.get_xlim()) + list(ax.get_ylim()))
        if end_labels is not None:
            # Place the two labels 5% of the bar's span beyond each end.
            if orientation == 'horizontal':
                delta = np.diff(lims[:2]) * np.array([-0.05, 0.05])
                xs = np.array(lims[:2]) + delta
                has = ['right', 'left']
                ys = [lims[2:].mean()] * 2
                vas = ['center', 'center']
            else:
                xs = [lims[:2].mean()] * 2
                has = ['center'] * 2
                delta = np.diff(lims[2:]) * np.array([-0.05, 0.05])
                ys = lims[2:] + delta
                vas = ['top', 'bottom']
            for x, y, l, ha, va in zip(xs, ys, end_labels, has, vas):
                ax.text(x, y, l, ha=ha, va=va, fontsize=ticklabelsize)
        if adjust:
            fig.subplots_adjust(0.01, 0.05, 0.2, 0.95)
    return fig
示例#12
0
def plot_slice(vstruct,
               scentre,
               snormal,
               cell_size=None,
               contourf=True,
               cmap='viridis',
               cmap_range=(None, None),
               bval=np.nan,
               cbar_tick_rot=0,
               show_corners=True,
               orientation=None,
               alter_bbox=(0., 0., 0., 0.),
               angle_step=1.,
               dist_tol=1e-5,
               min_voxels=None,
               max_voxels=None):
    """Plot 2D slices through the cell/density elements of a structure.

    One subplot, with a horizontal colorbar beneath it, is created for each
    'repeat_cell' or 'repeat_density' element of ``vstruct`` (after
    ``apply_transforms``); elements of other types are ignored.

    Parameters
    ----------
    vstruct: dict
        visualisation structure; after ``apply_transforms`` it must contain an
        "elements" list (element schema defined elsewhere)
    scentre: Tuple
        point on slice plane (x, y, z)
    snormal: Tuple
        normal of slice plane (x, y, z)
    cell_size: float
        length of discretised cells. If None, cell_size = <minimum cube length> * 0.01
    contourf: bool or list of bools
        use filled contour, else lines. If list, set independently for each element
    cmap: str or list of str
        cmap type (used for 'repeat_density' elements). If list, set independently for each element
    cmap_range: tuple or list of tuples
        range for colors. If list, set independently for each element
    bval: float or list of floats
        value to set as background for contour plots. If list, set independently for each element
    cbar_tick_rot: float
        rotation of colorbar axis tick labels
    show_corners: bool or list of bool
        whether to show real space (x,y,z) plot corner positions. If list, set independently for each element
    orientation: int or None
        between 0 and 3, select a specific bbox orientation (rotated by orientation * 90 degrees)
        if None, the orientation is selected such that corner min(x',y') -> min(x,y,z)
    alter_bbox: tuple of floats
        move edges of computed bbox (bottom, top, left, right)
    angle_step: float
        angular step (degrees) for mapping plane intersection with bounding box
    dist_tol: float
        distance tolerance for finding edge of bounding box
    min_voxels : int
        minimum number of voxels in cartesian density cube
    max_voxels : int
        maximum number of voxels in cartesian density cube

    Returns
    -------
    fig: matplotlib.figure.Figure
    final_axes: list
        [(ax, cbar_ax), ...] for each element

    Raises
    ------
    ValueError
        If no 'repeat_cell' or 'repeat_density' element is present.

    """
    # cbar_fmt: matplotlib.ticker.Formatter
    #        formatter for converting colorbar tick labels to str, if None use scientific notation
    new_struct = apply_transforms(vstruct)

    acceptable_elements = [
        e for e in new_struct["elements"]
        if e["type"] in ["repeat_cell", "repeat_density"]
    ]
    num_elements = len(acceptable_elements)
    if num_elements == 0:
        raise ValueError(
            "no 'repeat_cell' or 'repeat_density' elements present in vstruct")

    # Broadcast scalar options to one value per plotted element.
    if isinstance(contourf, bool):
        contourf = [contourf for _ in range(num_elements)]
    if not isinstance(cmap, (list, tuple)):
        cmap = [cmap for _ in range(num_elements)]
    if not isinstance(bval, (list, tuple)):
        bval = [bval for _ in range(num_elements)]
    if not isinstance(cmap_range, list):
        cmap_range = [cmap_range for _ in range(num_elements)]
    if isinstance(show_corners, bool):
        show_corners = [show_corners for _ in range(num_elements)]

    fig, axes = plt.subplots(1,
                             num_elements,
                             subplot_kw={"aspect": "equal"},
                             sharex='all',
                             sharey='all',
                             squeeze=True)
    fig = fig  # type: Figure
    # With squeeze=True a single subplot is returned bare, not in an array.
    if num_elements == 1:
        axes = [axes]

    final_axes = []

    # Iterate only over the plottable elements, so that each one is paired
    # with its own axes and its own per-element options.
    for element, ax, el_contourf, el_cmap, el_bval, (
            el_vmin, el_vmax), el_corners in zip(acceptable_elements, axes,
                                                 contourf, cmap, bval,
                                                 cmap_range, show_corners):
        ax = ax  # type: Axes

        if element["type"] == "repeat_density":
            cbar_title = "{0} ({1})".format(element["name"], element["dtype"])

            print("running cube_frac2cart")
            out = cube_frac2cart(element['dcube'],
                                 element['cell_vectors']['a'],
                                 element['cell_vectors']['b'],
                                 element['cell_vectors']['c'],
                                 element['centre'],
                                 max_voxels=max_voxels,
                                 min_voxels=min_voxels,
                                 make_cubic=False)
            ccube, (xmin, ymin, zmin), (xmax, ymax, zmax) = out
            print("running cubesliceplane")
            corners, corners_xy, gvalues_xy = cubesliceplane(
                ccube, (xmin, xmax, ymin, ymax, zmin, zmax),
                scentre,
                snormal,
                cell_size=cell_size,
                bval=el_bval,
                orientation=orientation,
                alter_bbox=alter_bbox,
                angle_step=angle_step,
                dist_tol=dist_tol)
            x, y, z = gvalues_xy.T
            el_colormap = get_cmap(el_cmap)
        elif element["type"] == "repeat_cell":
            cbar_title = "{0} ({1})".format(element["name"], "nuclei")
            centre = np.asarray(element['centre'], dtype=float)
            v1 = np.asarray(element['cell_vectors']['a'])
            v2 = np.asarray(element['cell_vectors']['b'])
            v3 = np.asarray(element['cell_vectors']['c'])
            bbox_pts = np.asarray([
                np.array([0.0, 0.0, 0.0]), v1, v2, v3, v1 + v2, v1 + v3,
                v1 + v2 + v3, v2 + v3
            ])
            bbox_x, bbox_y, bbox_z = bbox_pts.T
            xmin, xmax, ymin, ymax, zmin, zmax = (bbox_x.min(), bbox_x.max(),
                                                  bbox_y.min(), bbox_y.max(),
                                                  bbox_z.min(), bbox_z.max()
                                                  )  # l,r,bottom,top
            # Translate the cell's bounding box so it is centred on `centre`
            # (all three axes, including z).
            xmin, ymin, zmin = np.array(
                (xmin, ymin, zmin)) - 0.5 * (v1 + v2 + v3) + np.array(centre)
            xmax, ymax, zmax = np.array(
                (xmax, ymax, zmax)) - 0.5 * (v1 + v2 + v3) + np.array(centre)
            corners, corners_xy, gpoints, gpoints_xy = sliceplane_points(
                (xmin, xmax, ymin, ymax, zmin, zmax), scentre, snormal,
                cell_size, orientation, alter_bbox, angle_step, dist_tol)

            gvalues = np.full((gpoints_xy.shape[0], ), 0, dtype=np.float64)
            # Map each distinct (label, color) pair to an index starting at 1;
            # 0 is reserved for background (points outside every site).
            color_map = {(d[0], d[1]): i + 1
                         for i, d in enumerate(
                             sorted({(site["label"], site["color_fill"])
                                     for site in element["sites"]}))}
            for site in element["sites"]:
                mask = np.abs(np.linalg.norm(gpoints - site["ccoord"],
                                             axis=1)) < site["radius"]
                gvalues[mask] = color_map[(site["label"], site["color_fill"])]
            x, y = gpoints_xy.T
            z = gvalues

            # Build a colormap whose entry i is the color of index i
            # (white for the background index 0).
            v2colmap = {v: k[1] for k, v in color_map.items()}
            clist = ["white"] + [v2colmap[k] for k in sorted(v2colmap.keys())]
            el_colormap = LinearSegmentedColormap.from_list("cmap_name",
                                                            clist,
                                                            N=len(clist))
        else:
            continue

        el_vmin = np.nanmin(z) if el_vmin is None else el_vmin
        el_vmax = np.nanmax(z) if el_vmax is None else el_vmax

        cbar_fmt = ticker.FuncFormatter(fmt_scientific)
        # If vmin/vmax share a similar, large-magnitude exponent, rescale the
        # data by that power of ten and note the factor in the colorbar title.
        min_exp, min_diff_exp = (2, 2)
        exp_min, exp_max = int('{:.2e}'.format(el_vmin).split('e')[1]), int(
            '{:.2e}'.format(el_vmax).split('e')[1])
        if abs(exp_min - exp_max) < min_diff_exp and abs(exp_min) > min_exp:
            el_multiplier = 10**float(exp_min)
            el_vmin /= el_multiplier
            el_vmax /= el_multiplier
            z /= el_multiplier
            cbar_title += r" $\times 10^{{{}}}$".format(int(exp_min))

        # Scatter the (x, y, z) samples onto a regular grid; unset cells are NaN.
        x_axis, x_arr = np.unique(x, return_inverse=True)
        y_axis, y_arr = np.unique(y, return_inverse=True)
        z_array = np.full((np.max(y_arr) + 1, np.max(x_arr) + 1), np.nan)
        z_array[y_arr, x_arr] = z

        print("plotting contour")
        if el_contourf:
            cset = ax.contourf(x_axis,
                               y_axis,
                               z_array,
                               cmap=el_colormap,
                               vmin=el_vmin,
                               vmax=el_vmax,
                               extend='both')
        else:
            cset = ax.contour(x_axis,
                              y_axis,
                              z_array,
                              cmap=el_colormap,
                              vmin=el_vmin,
                              vmax=el_vmax,
                              extend='both')

        norm = Normalize(vmin=el_vmin, vmax=el_vmax)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('bottom', size='5%', pad=0.3)
        # see https://matplotlib.org/devdocs/tutorials/colors/colorbar_only.html
        cbar = ColorbarBase(cax,
                            cmap=el_colormap,
                            norm=norm,
                            orientation="horizontal",
                            ticklocation="bottom",
                            extend="both",
                            format=cbar_fmt)
        cbar.set_label(cbar_title, fontsize=8)

        if element["type"] == "repeat_cell":
            # Label each site index with its site label.
            v2labelmap = {v: k[0] for k, v in color_map.items()}
            cbar.set_ticks(sorted(v2labelmap.keys()))
            cbar.set_ticklabels(
                [v2labelmap[k] for k in sorted(v2labelmap.keys())])

        cax.tick_params(labelsize=8)
        for tick in cax.get_xticklabels():
            tick.set_rotation(cbar_tick_rot)

        if el_corners:
            crnrs = [c for c in zip(corners_xy, corners)]
            for cxy, c3d in [crnrs[2], crnrs[0], crnrs[3], crnrs[1]]:
                ax.scatter([cxy[0]], [cxy[1]],
                           label="({0:.2f}, {1:.2f}, {2:.2f})".format(*c3d),
                           edgecolor='black')
            ax.legend(ncol=2,
                      loc='lower center',
                      fontsize="x-small",
                      bbox_to_anchor=(0.5, 1.0),
                      title="Coordinate Mapping",
                      framealpha=0.5)

        final_axes.append((ax, cax))

    return fig, final_axes
示例#13
0
def plotter(fdict):
    """ Go

    Build a two-panel figure of hourly percentiles for one ASOS station.

    Top panel: for each local hour, the value of ``ctx["v"]`` at the
    0, 5, ..., 100th percentiles, colored by hour with a colorbar.
    Bottom panel: percentile rank of each observation in ``date.year``
    (only drawn when that year has data). ``ctx["opt"]`` ("day", "week",
    "month", or anything else for all year) restricts which observations
    are used.

    Returns
    -------
    (fig, qtile) : the figure and the per-hour quantile DataFrame.

    Raises
    ------
    NoDataFound
        If the query returns no rows.
    """
    pgconn = get_dbconn("asos")

    ctx = get_autoplot_context(fdict, get_description())
    station = ctx["zstation"]
    date = ctx["date"]
    opt = ctx["opt"]
    varname = ctx["v"]

    tzname = ctx["_nt"].sts[station]["tzname"]

    # Resolve how to limit the query data
    limiter = ""
    if opt == "day":
        limiter = (f" and to_char(valid at time zone '{tzname}', 'mmdd') = "
                   f"'{date.strftime('%m%d')}' ")
        subtitle = (f"For Date of {date.strftime('%-d %b')}, "
                    f"{date.strftime('%-d %b %Y')} plotted in bottom panel")
        datefmt = "%I %p"
    elif opt == "week":
        limiter = f" and extract(week from valid) = {date.strftime('%V')} "
        subtitle = (
            f"For ISO Week of {date.strftime('%V')}, "
            f"week of {date.strftime('%-d %b %Y')} plotted in bottom panel")
        datefmt = "%-d %b"
    elif opt == "month":
        limiter = f" and extract(month from valid) = {date.strftime('%m')} "
        subtitle = (f"For Month of {date.strftime('%B')}, "
                    f"{date.strftime('%b %Y')} plotted in bottom panel")
        datefmt = "%-d"
    else:
        subtitle = f"All Year, {date.year} plotted in bottom panel"
        datefmt = "%-d %b"

    # Load up all the values, since we need pandas to do some heavy lifting
    # `hr` is the local hour after adding 10 minutes, so observations at :50
    # or later count toward the following hour.
    # NOTE(review): `varname` and `limiter` (which embeds `tzname`) are
    # interpolated directly into the SQL text; only station/tzname in the
    # %s placeholders are bound parameters. This is safe only if
    # get_autoplot_context restricts ctx["v"] to a fixed choice list -- confirm.
    # NOTE(review): the meaning of report_type = 2 is not visible here.
    obsdf = read_sql(
        f"""
        select valid at time zone 'UTC' as utc_valid,
        extract(year from valid at time zone %s)  as year,
        extract(hour from valid at time zone %s +
            '10 minutes'::interval)::int as hr, {varname}
        from alldata WHERE station = %s and {varname} is not null {limiter}
        and report_type = 2 ORDER by valid ASC
    """,
        pgconn,
        params=(tzname, tzname, station),
        index_col=None,
    )
    if obsdf.empty:
        raise NoDataFound("No data was found.")

    # Assign percentiles
    # (percentile rank of each observation within its hour, in (0, 1])
    obsdf["quantile"] = obsdf[["hr", varname]].groupby("hr").rank(pct=True)
    # Compute actual percentiles
    # (value at the 0, 5, ..., 100th percentile for every hour; the quantile
    # level column produced by reset_index is renamed to "quantile")
    qtile = (obsdf[["hr", varname
                    ]].groupby("hr").quantile(np.arange(0, 1.01,
                                                        0.05)).reset_index())
    qtile = qtile.rename(columns={"level_1": "quantile"})
    (fig, ax) = plt.subplots(2, 1)
    cmap = get_cmap(ctx["cmap"])
    # Top panel: one curve per hour, colored by hr / 23.
    for hr, gdf in qtile.groupby("hr"):
        ax[0].plot(
            gdf["quantile"].values * 100.0,
            gdf[varname].values,
            color=cmap(hr / 23.0),
            label=str(hr),
        )
    ax[0].set_xlim(0, 100)
    ax[0].grid(True)
    ax[0].set_ylabel(PDICT[varname])
    ax[0].set_xlabel("Percentile")
    ax[0].set_position([0.13, 0.55, 0.71, 0.34])
    cax = plt.axes([0.86, 0.55, 0.03, 0.33],
                   frameon=False,
                   yticks=[],
                   xticks=[])
    cb = ColorbarBase(cax, cmap=cmap)
    # Six ticks every 4 hours (positions 0, 4/24, ..., 20/24).
    # NOTE(review): lines are colored with hr / 23 but ticks are placed at
    # hr / 24, so the tick labels are slightly offset from the line colors.
    cb.set_ticks(np.arange(0, 1, 4.0 / 24.0))
    cb.set_ticklabels(["Mid", "4 AM", "8 AM", "Noon", "4 PM", "8 PM"])
    cb.set_label("Local Hour")

    # Bottom panel: percentile rank of every observation in the chosen year.
    thisyear = obsdf[obsdf["year"] == date.year]
    if not thisyear.empty:
        ax[1].plot(thisyear["utc_valid"].values,
                   thisyear["quantile"].values * 100.0)
        ax[1].grid(True)
        ax[1].set_ylabel("Percentile")
        ax[1].set_ylim(-1, 101)
        ax[1].xaxis.set_major_formatter(
            mdates.DateFormatter(datefmt, tz=pytz.timezone(tzname)))
        if opt == "day":
            ax[1].set_xlabel(f"Timezone: {tzname}")
    title = ("%s %s %s Percentiles\n%s") % (
        station,
        ctx["_nt"].sts[station]["name"],
        PDICT[varname],
        subtitle,
    )
    fitbox(fig, title, 0.01, 0.99, 0.91, 0.99, ha="center", va="center")
    return fig, qtile
示例#14
0
def networkx_draw(G, pos, myfigsize, which_nodes, mytitle, with_names, n_edges,
                  edges2draw, edge_colors2draw, edge_widths2draw,
                  edgesNOT2draw, edge_colorsNOT2draw, edge_widthsNOT2draw,
                  edge_line_width, nodes2draw, nodesNOT2draw, node_colors2draw,
                  node_sizes2draw, node_colorsNOT2draw, node_sizesNOT2draw,
                  node_line_width, cmap_edge, norm_edge, cmap_node, norm_node,
                  clim_node, nodeCmap, nodeCmap_ticklabels, print_node_names,
                  print_edge_names, font_size, font_color, nodes_as_pies):
    """Draw graph ``G`` with an edge colorbar (right) and a node colorbar (bottom).

    The figure is split on a 500x500 ``subplot2grid``. The network occupies
    the main area, a 10-cell-wide strip on the right holds the edge colorbar,
    and a 10-cell-high strip at the bottom holds the node colorbar.

    Parameters
    ----------
    G : networkx graph
        Graph to draw.
    pos : dict
        Maps each node to its ``(x, y)`` position.
    myfigsize : tuple
        Figure size. ``myfigsize[0]`` is also used as the colorbar tick label size.
    which_nodes : object or None
        When not None, the ``*NOT2draw`` edges and nodes are drawn faded
        (alpha 0.25 edges, light-grey nodes) behind the highlighted ones.
    mytitle : str
        Title of the network axes.
    with_names : sequence/mapping or None
        Custom node labels. When None, ``str(node)`` is used.
    n_edges : int
        Edges (and the edge colorbar) are drawn only when > 0.
    edges2draw, edgesNOT2draw : list
        Edge lists for highlighted / faded edges. With ``print_edge_names``
        the items of ``edges2draw`` are read as ``(u, v, data)`` with a
        ``'weight'`` key.
    edge_colors2draw, edge_widths2draw, edge_colorsNOT2draw, edge_widthsNOT2draw
        Per-edge colours and widths.
    edge_line_width : float
        If negative, per-edge widths are scaled by ``abs(edge_line_width)``.
        Otherwise it is a uniform width (faded edges get half of it).
    nodes2draw, nodesNOT2draw : list
        Highlighted / faded nodes.
    node_colors2draw, node_sizes2draw, node_colorsNOT2draw, node_sizesNOT2draw
        Node colours and sizes. In pie mode the colour arrays are indexed as
        ``[node, :]`` (one wedge value per column) and sizes as ``[node]``.
    node_line_width : float
        Node outline width. Outlines are drawn only when > 0.
    cmap_edge, norm_edge : Colormap, Normalize
        Colormap / normalization of the edge colorbar.
    cmap_node, norm_node : Colormap, Normalize
        Colormap / normalization of the node colorbar (and pie wedge colours).
    clim_node : tuple
        ``(min, max)`` used for the integer ticks of the node colorbar.
    nodeCmap : str
        When ``'btwn'`` the node colorbar keeps Matplotlib's default ticks.
    nodeCmap_ticklabels : list or None
        Optional tick labels for the node colorbar.
    print_node_names, print_edge_names : bool
        Whether to draw node / edge labels.
    font_size, font_color
        Label font settings (``font_size`` also sizes the title).
    nodes_as_pies : bool
        Draw each node as a small pie chart instead of a marker.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Columns are nodes, rows are the x and y coordinates.
    _positions = pd.DataFrame.from_dict(pos)
    pos_min_lims = _positions.min(axis=1).values
    pos_max_lims = _positions.max(axis=1).values

    # ---- DRAW ---- #
    myfig = plt.figure(figsize=myfigsize)

    # r x c grid cells; s cells are reserved for each colorbar strip.
    r = 500
    c = 500
    s = 10
    ax1 = plt.subplot2grid((r, c), (0, 0), colspan=c - s, rowspan=r - s)
    ax2 = plt.subplot2grid((r, c), (0, c - s), colspan=s, rowspan=r - s)
    ax3 = plt.subplot2grid((r, c), (r - s, 0), colspan=c - s, rowspan=s)

    # Edges are drawn before nodes, so the axis limits must be set explicitly.
    if nodes_as_pies:
        a_size = (pos_max_lims.max() - pos_min_lims.min()) / 10
        a2 = a_size / 2.0
    else:
        a2 = 0.2
    ax1.set_xlim(pos_min_lims[0] - a2, pos_max_lims[0] + a2)
    ax1.set_ylim(pos_min_lims[1] - a2, pos_max_lims[1] + a2)

    if n_edges > 0:
        if edge_line_width < 0:
            # Negative width: scale each edge's own width by |edge_line_width|.
            if which_nodes is not None:
                nx.draw_networkx_edges(G,
                                       pos,
                                       edgelist=edgesNOT2draw,
                                       width=edge_widthsNOT2draw *
                                       abs(edge_line_width),
                                       edge_color=edge_colorsNOT2draw,
                                       alpha=0.25,
                                       ax=ax1)
            nx.draw_networkx_edges(G,
                                   pos,
                                   edgelist=edges2draw,
                                   width=edge_widths2draw * abs(edge_line_width),
                                   edge_color=edge_colors2draw,
                                   ax=ax1)
        else:
            # Non-negative width: one uniform width; faded edges are grey.
            if which_nodes is not None:
                nx.draw_networkx_edges(G,
                                       pos,
                                       edgelist=edgesNOT2draw,
                                       width=edge_line_width / 2,
                                       edge_color='grey',
                                       alpha=0.25,
                                       ax=ax1)
            nx.draw_networkx_edges(G,
                                   pos,
                                   edgelist=edges2draw,
                                   width=edge_line_width,
                                   edge_color=edge_colors2draw,
                                   ax=ax1)

        if print_edge_names:
            edge_labels = {
                e[0:2]: '{}'.format(e[2]['weight'])
                for e in edges2draw
            }
            edge_label_handles = nx.draw_networkx_edge_labels(
                G, pos, edge_labels=edge_labels, ax=ax1)
            for label in edge_label_handles.values():
                label.set_bbox(dict(facecolor='none', edgecolor='none'))

        # Edge colorbar in the right vertical strip.
        bar = ColorbarBase(cmap=cmap_edge, norm=norm_edge, ax=ax2)
        bar.ax.tick_params(labelsize=myfigsize[0])
    else:
        ax2.axis('off')

    if nodes_as_pies:
        if print_node_names:
            if with_names is None:
                node_labels = {e: str(e) for e in sorted(G.nodes())}
            else:
                # NOTE(review): with_names is indexed by node here, but by
                # position in sorted(nodes2draw) in the marker branch below.
                node_labels = {e: with_names[e] for e in sorted(nodes2draw)}

        # data -> display coordinates
        trans = ax1.transData.transform
        # display -> figure coordinates
        trans2 = myfig.transFigure.inverted().transform

        wedge_colors = [cmap_node(i) for i in range(node_colors2draw.shape[1])]

        # NOTE(review): a2 / a_size are computed in data units but are used
        # below as figure fractions for the pie axes; confirm the intent.
        if which_nodes is not None:
            if node_line_width > 0:
                grey_wedgeprops = {
                    'linewidth': node_line_width,
                    'edgecolor': 'lightgrey'
                }
            else:
                grey_wedgeprops = {}
            for n in nodesNOT2draw:
                xx, yy = trans(pos[n])  # display coordinates
                xa, ya = trans2((xx, yy))  # figure coordinates
                a = plt.axes([xa - a2, ya - a2, a_size, a_size])
                a.pie(node_colorsNOT2draw[n, :],
                      radius=node_sizesNOT2draw[n],
                      colors=['lightgrey'],
                      wedgeprops=grey_wedgeprops)

        if node_line_width > 0:
            wedgeprops = {'linewidth': node_line_width, 'edgecolor': 'black'}
        else:
            wedgeprops = {}

        for n in nodes2draw:
            xx, yy = trans(pos[n])  # display coordinates
            xa, ya = trans2((xx, yy))  # figure coordinates
            a = plt.axes([xa - a2, ya - a2, a_size, a_size])
            a.pie(node_colors2draw[n, :],
                  radius=node_sizes2draw[n],
                  colors=wedge_colors,
                  wedgeprops=wedgeprops)
            if print_node_names:
                # Only one weight key: 'fontweight' and 'weight' are aliases,
                # and recent Matplotlib rejects a dict containing both.
                a.set_title(node_labels[n],
                            x=0.5,
                            y=0.5,
                            fontdict={
                                'fontsize': font_size,
                                'color': font_color,
                                'weight': 'bold',
                                'alpha': 0.8
                            })
    else:
        if which_nodes is not None:
            nodes = nx.draw_networkx_nodes(G,
                                           pos,
                                           nodelist=nodesNOT2draw,
                                           node_color='lightgrey',
                                           node_size=node_sizesNOT2draw,
                                           linewidths=0.1,
                                           ax=ax1)
            if nodes is not None:
                nodes.set_edgecolor('grey')

        nodes = nx.draw_networkx_nodes(G,
                                       pos,
                                       nodelist=nodes2draw,
                                       node_color=node_colors2draw,
                                       node_size=node_sizes2draw,
                                       linewidths=node_line_width,
                                       ax=ax1)
        # Some networkx versions return None for an empty node list.
        if node_line_width > 0 and nodes is not None:
            nodes.set_edgecolor('black')

        if print_node_names:
            if with_names is None:
                node_labels = {e: str(e) for e in sorted(G.nodes())}
            else:
                node_labels = {
                    e: with_names[i]
                    for i, e in enumerate(sorted(nodes2draw))
                }
            nx.draw_networkx_labels(G,
                                    pos,
                                    labels=node_labels,
                                    font_size=font_size,
                                    font_color=font_color,
                                    font_weight='bold',
                                    ax=ax1)

    # Node colorbar in the bottom horizontal strip. For 'btwn' the default
    # ticks are kept; otherwise there is one tick per integer in clim_node.
    if nodeCmap == 'btwn':
        node_ticks = None
    else:
        node_ticks = np.arange(clim_node[0], clim_node[1] + 1, 1)
    bar = ColorbarBase(cmap=cmap_node,
                       norm=norm_node,
                       ticks=node_ticks,
                       orientation='horizontal',
                       ax=ax3)
    if nodeCmap_ticklabels is not None:
        # The update_ticks argument was deprecated in Matplotlib 3.5. Ticks are
        # refreshed by default, so it is not passed.
        bar.set_ticklabels(nodeCmap_ticklabels)
    bar.ax.tick_params(labelsize=myfigsize[0])

    ax1.axis('off')
    ax1.set_title(mytitle, fontsize=font_size)

    return myfig
示例#15
0
class Contact_Plot(QtGui.QWidget):
    """Widget showing electrode contact impedance on a head outline.

    The left part is a Matplotlib canvas. It shows a scatter of the selected
    channels over an MNE head outline, coloured with ``RdYlGn_r`` on a
    0-2000 scale, plus a colorbar. The right part is a table of
    channel/impedance values. A QTimer periodically pulls new readings from
    ``WS_Imp``.
    """

    def __init__(self, url=None):
        """Build the figure, the table and the refresh timer.

        Parameters
        ----------
        url : str, optional
            WebSocket URL of the impedance server. Defaults to
            ``"ws://localhost:7777"``.
        """
        import time

        super().__init__()

        ws_url = "ws://localhost:7777" if url is None else url

        self.setWindowTitle("Contact Impedance")
        self.ws_imp = WS_Imp(contact_plot=self, url=ws_url)
        self.timer_interval = 0.5  # refresh period in seconds
        self.ch_label = list()

        gs_kw = dict(width_ratios=[30, 1], height_ratios=[1])
        self.fig, (self.ax, self.ax2) = plt.subplots(1, 2, gridspec_kw=gs_kw)

        self.canvas = FigureCanvas(self.fig)

        self.pos = list(channel_dict_2D.values())      # all (x, y) positions
        self.ch_names_ = list(channel_dict_2D.keys())  # all channel names

        # mne.viz helper: normalizes the positions and builds the head outline.
        self.pos, self.outlines = topomap._check_outlines(self.pos, 'head')

        self.cm = plt.cm.get_cmap('RdYlGn_r')
        self.norm = mpl.colors.Normalize(vmin=0, vmax=2000)
        self.colorbar = ColorbarBase(self.ax2, cmap=self.cm, norm=self.norm,
                                     ticks=[0, 500, 1000, 1500, 2000])
        self.colorbar.set_ticklabels(['0 KOhm', '500', '1000', '1500', '2000'])

        # Block until the channel labels are filled in, presumably by WS_Imp
        # on another thread (TODO confirm). Sleep briefly instead of spinning
        # so the producer gets CPU time.
        while not self.ch_label:
            time.sleep(0.01)

        # Indices (into pos / ch_names_) of the channels to display.
        self.plt_idx = [self.ch_names_.index(name) for name in self.ch_label]
        self.ch_table = QtGui.QTableWidget(len(self.ch_label), 2, parent=self)
        self.ch_table.setHorizontalHeaderItem(0, QtGui.QTableWidgetItem("Channel"))
        self.ch_table.setHorizontalHeaderItem(1, QtGui.QTableWidgetItem("Impedance (KOhm)"))
        self.draw(self.ch_label, [0] * len(self.ch_label))
        header = self.ch_table.horizontalHeader()
        header.setResizeMode(QtGui.QHeaderView.ResizeToContents)
        header.setStretchLastSection(True)

        hlayout = QtGui.QHBoxLayout(self)
        hlayout.addWidget(self.canvas)
        hlayout.addWidget(self.ch_table)

        self.setup_signal_handler()
        self.show()
        self.value = 0

    def update_plot(self):
        """Redraw using the oldest queued impedance reading, if there is one."""
        if self.ws_imp.impedance_data:
            self.ax.cla()
            self.draw(self.ch_label, self.ws_imp.impedance_data.pop(0))
            self.fig.canvas.draw()

    def draw(self, ch_label, values):
        """Plot ``values`` for the channels in ``ch_label`` and fill the table.

        Parameters
        ----------
        ch_label : list of str
            Channel names to display. Each must be a key of ``channel_dict_2D``
            (``list.index`` raises ValueError otherwise).
        values : sequence of float
            Impedance per channel, in the same order as ``ch_label``.
        """
        self.plt_idx = [self.ch_names_.index(name) for name in ch_label]

        # mne.viz helper: draws the head outline on the axes.
        topomap._draw_outlines(self.ax, self.outlines)

        self.ax.scatter(self.pos[self.plt_idx, 0], self.pos[self.plt_idx, 1],
                        c=values, marker='o', alpha=0.7, edgecolors=(0, 0, 0),
                        cmap=self.cm, norm=self.norm)
        for idx in self.plt_idx:
            x, y = self.pos[idx, 0], self.pos[idx, 1]
            # Nudge labels slightly away from the midline (x == 0).
            if x < 0:
                dx = -0.005
            elif x > 0:
                dx = 0.005
            else:
                dx = 0
            self.ax.text(x + dx, y - 0.04, self.ch_names_[idx],
                         fontsize=10, horizontalalignment='center')

        for row, (ch_name, value) in enumerate(zip(ch_label, values)):
            self.ch_table.setItem(row, 0, QtGui.QTableWidgetItem(ch_name))
            self.ch_table.setItem(row, 1, QtGui.QTableWidgetItem("{:.3f}".format(value)))

        self.ax.axis("off")

    def setup_signal_handler(self):
        """Start a QTimer that calls ``update_plot`` every ``timer_interval`` s."""
        self.timer = QtCore.QTimer()
        # QTimer.setInterval takes integer milliseconds. A float is rejected
        # on recent Python/PyQt versions.
        self.timer.setInterval(int(self.timer_interval * 1000))
        self.timer.timeout.connect(self.update_plot)
        self.timer.start()
示例#16
0
def main(traveltime=False, alerttime=False, blindzone=False,
         bzreduction=False, optimalbz=False, txt_fontsize=14):
    """Draw EEW maps for Europe with southern California, Iceland and NZ insets.

    The mapped quantity is selected by the boolean flags. If several are
    set, later flags (in the order below) override the colour range and the
    output file chosen by earlier ones. The figure is saved as a PDF under
    ``plots/`` and then shown.

    Parameters
    ----------
    traveltime : bool
        Map P-wave travel times to 6 stations (``ptt_*_6stations.npz``).
    alerttime : bool
        Map initial alert times from the per-region event lists.
    blindzone : bool
        Map no-warning (blind) zone radii.
    bzreduction : bool
        Map blind zone reduction.
    optimalbz : bool
        Map optimal no-warning zone radii.
    txt_fontsize : int
        Font size of the region-name annotations.

    Raises
    ------
    ValueError
        If none of the map flags is set. Previously this case did all the
        plotting and then failed on an undefined output file name.
    """
    if not (traveltime or alerttime or blindzone or bzreduction or optimalbz):
        raise ValueError('At least one of traveltime, alerttime, blindzone, '
                         'bzreduction or optimalbz must be True')

    meridians = np.arange(2.0, 40, 4.0)
    parallels = np.arange(33, 48, 2.0)
    cmapname = 'RdBu_r'
    cmap = cm.get_cmap(cmapname)
    cmap.set_over('grey')
    extend = 'max'
    dlevel = 0.5
    datadir = './data/'
    lbl_fontsize = 16

    cb_label = []
    if traveltime:
        vmin = 0.
        vmax = 15.
        fout = 'plots/travel_time_maps_reakt.pdf'
        cb_label.append('P-wave travel time to %d stations [s]' % 6)
        cb_label.append('Approximate inter-station distance [km]')

    if alerttime:
        vmin = 10.
        vmax = 60.
        fout = 'plots/alert_time_maps_reakt.pdf'
        cb_label.append('Initial alert time [s]')

    if blindzone:
        vmin = 10.
        vmax = 200.
        fout = 'plots/blind_zone_maps_reakt.pdf'
        cb_label.append('No-warning zone radius [km]')
        # Raw strings: '\D' is not a valid Python escape sequence.
        cb_label.append(r'Alert delay ($\Delta t_{alert}$) [s]')
        cb_label.append('Magnitude with positive EEW zone')

    if bzreduction:
        cmap = cm.get_cmap('RdBu')
        cmap.set_over('grey')
        vmin = 10
        vmax = 200
        fout = 'plots/blind_zone_reduction_reakt.pdf'
        cb_label.append('Blind zone reduction [km]')
        cb_label.append('Alert delay reduction [s]')

    if optimalbz:
        vmin = 10
        vmax = 200
        fout = 'plots/optimal_blind_zone_reakt.pdf'
        cb_label.append('Optimal no-warning zone radius [km]')
        cb_label.append(r'Optimal alert delay ($\Delta t_{alert}$) [s]')
        cb_label.append('Optimal magnitude with positive EEW zone')

    # Event-list maps are drawn for every mode except pure travel time.
    # bzreduction / optimalbz maps call plot_at with blindzone=True.
    plot_alert = alerttime or blindzone or bzreduction or optimalbz
    if bzreduction or optimalbz:
        blindzone = True

    # Damage zone: the zone for which intensity is >= 5.0, for magnitudes >= 5.0.
    mags, dz = damage_zone()

    fig = plt.figure(figsize=(16, 7))
    ax = fig.add_axes([0.05, 0., .8, 1.0])

    # Albers equal-area conic basemap. lat_1 / lat_2 are the standard
    # parallels and (lon_0, lat_0) is the central point.
    m = Basemap(width=5000000, height=2300000,
                resolution='l', projection='aea',
                lat_1=35., lat_2=48, lon_0=20, lat_0=42, ax=ax)
    m.drawmeridians(meridians, labels=[0, 0, 0, 1], color='lightgray',
                    linewidth=0.5, zorder=0, fontsize=lbl_fontsize)
    m.drawparallels(parallels, labels=[1, 0, 0, 0], color='lightgray',
                    linewidth=0.5, zorder=0, fontsize=lbl_fontsize)
    m.drawcoastlines(zorder=2)
    m.drawcountries(linewidth=1.0, zorder=2)
    m.fillcontinents('lightgray', zorder=0)

    # European regions: (file code, (lon, lat) of the name label, name).
    europe = (('ro', (25.0, 48.8), 'Romania'),
              ('gr', (19.5, 36.5), 'Greece'),
              ('ch', (8, 48.75), 'Switzerland'),
              ('tr', (34.5, 39.5), 'Turkey'))

    if traveltime:
        for code, (lon, lat), name in europe:
            resultsfn = os.path.join(datadir, 'ptt_%s_6stations.npz' % code)
            plot_ptt(resultsfn, m, cmap, dlevel, vmax=vmax)
            xtxt, ytxt = m(lon, lat)
            ax.text(xtxt, ytxt, name, fontsize=txt_fontsize,
                    horizontalalignment='center')

    if plot_alert:
        for code, _, _ in europe:
            resultsfn = os.path.join(datadir, 'event_list_%s.csv' % code)
            optbz = os.path.join(datadir, 'optimal_blindzone_%s.npz' % code)
            plot_at(resultsfn, m, cmap, vmin=vmin, vmax=vmax,
                    blindzone=blindzone, optbz=optbz, bzreduction=bzreduction,
                    optimalbz=optimalbz)

    darkgray = (107 / 255., 107 / 255., 107 / 255.)

    # Inset: southern California
    ax_cal = fig.add_axes([0.5, 0.53, .45, .35])
    mca = Basemap(width=700000, height=750000,
                  resolution='l', projection='aea',
                  lat_1=31, lat_2=38., lon_0=-117.5, lat_0=34.5,
                  ax=ax_cal)
    mca.drawcoastlines(zorder=2)
    mca.drawcountries(zorder=2)
    mca.drawstates(zorder=2)
    mca.fillcontinents(color=darkgray, zorder=0)
    mca.drawmeridians([-119, -115], labels=[0, 0, 1, 0],
                      color='lightgray', linewidth=0.5, zorder=0,
                      fontsize=lbl_fontsize)
    mca.drawparallels([32, 34, 36], labels=[0, 1, 0, 0],
                      color='lightgray', linewidth=0.5, zorder=0,
                      fontsize=lbl_fontsize)
    if traveltime:
        resultsfn = os.path.join(datadir, 'ptt_ca_6stations.npz')
        plot_ptt(resultsfn, mca, cmap, dlevel, vmin=vmin, vmax=vmax)
        xtxt, ytxt = mca(-121.25, 37.3)
        ax_cal.text(xtxt, ytxt, 'southern California', fontsize=txt_fontsize,
                    bbox=dict(facecolor='#eeeeee', alpha=1.0))
    if plot_alert:
        resultsfn = os.path.join(datadir, 'event_list_ca.csv')
        optbz = os.path.join(datadir, 'optimal_blindzone_ca.npz')
        plot_at(resultsfn, mca, cmap, vmin=vmin, vmax=vmax,
                blindzone=blindzone, optbz=optbz, bzreduction=bzreduction,
                optimalbz=optimalbz, ssize=50)

    # Inset: Iceland
    ax_ice = fig.add_axes([0.05, 0.63, .2, .25])
    mi = Basemap(width=550000, height=580000,
                 resolution='l', projection='aea',
                 lat_1=62.5, lat_2=68.5, lon_0=-19, lat_0=65,
                 ax=ax_ice)
    mi.drawcoastlines(zorder=2)
    mi.fillcontinents(color=darkgray, zorder=0)
    mi.drawmeridians(np.arange(-26, -12, 5), labels=[0, 0, 1, 0],
                     color='lightgray', linewidth=0.5, zorder=0,
                     fontsize=lbl_fontsize)
    mi.drawparallels(np.arange(60, 70, 2), labels=[0, 1, 0, 0],
                     color='lightgray', linewidth=0.5, zorder=0,
                     fontsize=lbl_fontsize)
    if traveltime:
        resultsfn = os.path.join(datadir, 'ptt_is_6stations.npz')
        plot_ptt(resultsfn, mi, cmap, dlevel, vmin=vmin, vmax=vmax)
        xtxt, ytxt = mi(-23.5, 67)
        ax_ice.text(xtxt, ytxt, 'Iceland', fontsize=txt_fontsize)
    if plot_alert:
        resultsfn = os.path.join(datadir, 'event_list_iceland_all.csv')
        optbz = os.path.join(datadir, 'optimal_blindzone_is.npz')
        plot_at(resultsfn, mi, cmap, vmin=vmin, vmax=vmax,
                blindzone=blindzone, optbz=optbz, bzreduction=bzreduction,
                optimalbz=optimalbz)

    # Inset: New Zealand
    ax_nz = fig.add_axes([-0.04, 0.12, .5, .45])
    mnz = Basemap(width=1300000, height=1700000,
                  resolution='l', projection='aea',
                  lat_1=-50., lat_2=-32, lon_0=172, lat_0=-41,
                  ax=ax_nz)
    mnz.drawcoastlines(zorder=2)
    mnz.fillcontinents(color=darkgray, zorder=0)
    mnz.drawmeridians(np.arange(164, 182, 6), labels=[0, 0, 0, 1],
                      color='lightgray', linewidth=0.5, zorder=0,
                      fontsize=lbl_fontsize)
    mnz.drawparallels(np.arange(-51, -31, 2), labels=[1, 0, 0, 0],
                      color='lightgray', linewidth=0.5, zorder=0,
                      fontsize=lbl_fontsize)
    if traveltime:
        resultsfn = os.path.join(datadir, 'ptt_nz_6stations.npz')
        plot_ptt(resultsfn, mnz, cmap, dlevel, vmin=vmin, vmax=vmax)
        xtxt, ytxt = mnz(165.5, -37)
        ax_nz.text(xtxt, ytxt, 'New Zealand', fontsize=txt_fontsize)
    if plot_alert:
        resultsfn = os.path.join(datadir, 'event_list_nz.csv')
        optbz = os.path.join(datadir, 'optimal_blindzone_nz.npz')
        plot_at(resultsfn, mnz, cmap, vmin=vmin, vmax=vmax,
                blindzone=blindzone, optbz=optbz, bzreduction=bzreduction,
                optimalbz=optimalbz)

    # ---- colorbars ----
    # NOTE(review): an alerttime-only map gets no colorbar even though a label
    # for it is collected in cb_label; confirm whether one is intended.
    cb_fontsize = 14
    lbl_fontsize = 14
    lbl_pad = 10
    cax = fig.add_axes([0.87, 0.1, 0.01, 0.8])
    if traveltime:
        cb = ColorbarBase(cax, cmap=cmap, norm=Normalize(vmin=vmin, vmax=vmax),
                          orientation='vertical', extend=extend)
        cb.set_label(cb_label[0], fontsize=cb_fontsize, labelpad=lbl_pad)
        cb.ax.tick_params(labelsize=lbl_fontsize)
        # Second axis: the same travel times expressed as station spacing.
        cax2 = fig.add_axes([.95, 0.1, 0.01, 0.8])
        cb2 = ColorbarBase(cax2, cmap=cmap, norm=Normalize(vmin=vmin, vmax=vmax),
                           orientation='vertical', extend=extend)
        cb2.set_label(cb_label[1], fontsize=cb_fontsize, labelpad=lbl_pad)
        cb2.set_ticks(np.arange(1.5, 16.5, 1.5))
        cb2.set_ticklabels(station_distance(np.arange(1.5, 16.5, 1.5), 8.0, 6.5, 6))
        cb2.ax.tick_params(labelsize=lbl_fontsize)
    if (blindzone or optimalbz) and not bzreduction:
        cb = ColorbarBase(cax, cmap=cmap, norm=LogNorm(vmin=vmin, vmax=vmax),
                          orientation='vertical', extend=extend)
        ticks = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200])
        cb.set_ticks(ticks)
        cb.set_ticklabels(ticks)
        cb.set_label(cb_label[0], fontsize=cb_fontsize, labelpad=lbl_pad)
        cb.ax.tick_params(labelsize=lbl_fontsize)
        cax2 = fig.add_axes([0.95, 0.1, 0.01, 0.8])
        cax3 = fig.add_axes([1.03, 0.1, 0.01, 0.8])
        # Second axis: blind-zone radius converted to alert delay.
        cb2 = ColorbarBase(cax2, cmap=cmap, norm=LogNorm(vmin=vmin, vmax=vmax),
                           orientation='vertical', extend=extend)
        cb2.set_label(cb_label[1], fontsize=cb_fontsize, labelpad=lbl_pad)
        cb2.set_ticks(ticks)
        cb2.set_ticklabels([int(x / 3.5 + 0.5) for x in ticks])
        cb2.ax.tick_params(labelsize=lbl_fontsize)
        # Third axis: magnitude whose damage zone matches the radius
        # (hypocentral distance assuming a depth of 8 km: sqrt(r^2 + 64)).
        cb3 = ColorbarBase(cax3, cmap=cmap, norm=LogNorm(vmin=vmin, vmax=vmax),
                           orientation='vertical', extend=extend)
        tklbl = [mags[np.argmin(np.abs(dz - np.sqrt(_bz * _bz + 64)))]
                 for _bz in ticks]
        cb3.set_label(cb_label[2], fontsize=cb_fontsize, labelpad=lbl_pad)
        cb3.set_ticks(ticks)
        cb3.set_ticklabels(tklbl)
        cb3.ax.tick_params(labelsize=lbl_fontsize)
    if bzreduction:
        cb = ColorbarBase(cax, cmap=cmap, norm=LogNorm(vmin=vmin, vmax=vmax),
                          orientation='vertical', extend=extend)
        ticks = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200])
        cb.set_ticks(ticks)
        cb.set_ticklabels(ticks)
        cb.set_label(cb_label[0], fontsize=cb_fontsize, labelpad=lbl_pad)
        cb.ax.tick_params(labelsize=lbl_fontsize)
        cax2 = fig.add_axes([.95, 0.1, 0.01, 0.8])
        cb2 = ColorbarBase(cax2, cmap=cmap, norm=LogNorm(vmin=vmin, vmax=vmax),
                           orientation='vertical', extend=extend)
        cb2.set_label(cb_label[1], fontsize=cb_fontsize, labelpad=lbl_pad)
        cb2.set_ticks(ticks)
        cb2.set_ticklabels([int(x / 6.5 + 0.5) for x in ticks])
        cb2.ax.tick_params(labelsize=lbl_fontsize)

    fig.savefig(fout, bbox_inches='tight')
    plt.show()
示例#17
0
def sign_plot(x,
              g=None,
              flat=False,
              labels=True,
              cmap=None,
              cbar_ax_bbox=None,
              ax=None,
              **kwargs):
    """
    Significance plot, a heatmap of p values (based on Seaborn).

    Parameters
    ----------
    x : array_like, ndarray or DataFrame
        If flat is False (default), x must be an array, any object exposing
        the array interface, containing p values. If flat is True, x must be
        a sign_array (returned by `scikit_posthocs.sign_array` function).

    g : array_like or ndarray, optional
        An array, any object exposing the array interface, containing
        group names. Ignored when `x` is a DataFrame. Defaults to
        ``0 .. n-1`` when omitted or empty.

    flat : bool, optional
        If `flat` is True, plots a significance array as a heatmap using
        seaborn. If `flat` is False (default), plots an array of p values.
        Non-flat mode is useful if you need to differentiate significance
        levels visually. It is the preferred mode.

    labels : bool, optional
        Plot axes labels (default) or not.

    cmap : list, optional
        1) If flat is False (default):
        List consisting of five elements, that will be exported to
        ListedColormap method of matplotlib. First is for diagonal
        elements, second is for non-significant elements, third is for
        p < 0.001, fourth is for p < 0.01, fifth is for p < 0.05.

        2) If flat is True:
        List consisting of three elements, that will be exported to
        ListedColormap method of matplotlib. First is for diagonal
        elements, second is for non-significant elements, third is for
        significant ones.

        3) If not defined, default colormaps will be used.

    cbar_ax_bbox : list, optional
        Colorbar axes position rect [left, bottom, width, height] where
        all quantities are in fractions of figure width and height.
        Refer to `matplotlib.figure.Figure.add_axes` for more information.
        Default is [0.95, 0.35, 0.04, 0.3].

    ax : matplotlib Axes, optional
        Axes in which to draw the plot, otherwise use the currently-active Axes.

    kwargs : other keyword arguments
        Keyword arguments to be passed to seaborn heatmap method. These
        keyword args cannot be used (they are silently dropped): cbar, vmin,
        vmax, center.

    Returns
    -------
    Axes object with the heatmap (and the ColorbarBase object of the cbar if
    `flat` is False).

    Raises
    ------
    ValueError
        If the dtype of `x` does not match the mode (integers for flat,
        floats otherwise), or if a non-flat `cmap` does not have 5 items.

    Examples
    --------
    >>> x = np.array([[-1,  1,  1],
                      [ 1, -1,  0],
                      [ 1,  0, -1]])
    >>> ph.sign_plot(x, flat = True)

    """

    # These are controlled by this function and must not reach heatmap().
    for key in ('cbar', 'vmin', 'vmax', 'center'):
        kwargs.pop(key, None)

    if isinstance(x, DataFrame):
        df = x.copy()
    else:
        x = np.array(x)
        # Explicit check: `g or ...` raises "truth value of an array is
        # ambiguous" when g is a numpy array or pandas Index.
        if g is None or len(g) == 0:
            g = np.arange(x.shape[0])
        df = DataFrame(x, index=g, columns=g)

    dtype = df.values.dtype

    if not np.issubdtype(dtype, np.integer) and flat:
        raise ValueError(
            "X should be a sign_array or DataFrame of integer values")
    elif not np.issubdtype(dtype, np.floating) and not flat:
        raise ValueError("X should be an array or DataFrame of float p values")

    if not cmap and flat:
        # format: diagonal, non-significant, significant
        cmap = ['1', '#fbd7d4', '#1a9641']
    elif not cmap and not flat:
        # format: diagonal, non-significant, p<0.001, p<0.01, p<0.05
        cmap = ['1', '#fbd7d4', '#005a32', '#238b45', '#a1d99b']

    if flat:
        g = heatmap(df,
                    vmin=-1,
                    vmax=1,
                    cmap=ListedColormap(cmap),
                    cbar=False,
                    ax=ax,
                    **kwargs)
        if not labels:
            g.set_xlabel('')
            g.set_ylabel('')
        return g

    if len(cmap) != 5:
        raise ValueError("Cmap list must contain 5 items")

    # Encode significance levels: -1 diagonal, 0 NS, 1 p<.001, 2 p<.01,
    # 3 p<.05. The masks are computed on the untouched `x`, so earlier
    # assignments to df do not affect later ones.
    df[(x <= 0.001) & (x >= 0)] = 1
    df[(x <= 0.01) & (x > 0.001)] = 2
    df[(x <= 0.05) & (x > 0.01)] = 3
    df[(x > 0.05)] = 0
    np.fill_diagonal(df.values, -1)

    g = heatmap(df,
                vmin=-1,
                vmax=3,
                cmap=ListedColormap(cmap),
                center=1,
                cbar=False,
                ax=ax,
                **kwargs)
    if not labels:
        g.set_xlabel('')
        g.set_ylabel('')

    # Custom categorical colorbar: significance colours first, then NS.
    # list() lets cmap be a tuple as well as a list.
    cbar_ax = g.figure.add_axes(cbar_ax_bbox or [0.95, 0.35, 0.04, 0.3])
    cbar = ColorbarBase(cbar_ax,
                        cmap=ListedColormap(list(cmap[2:]) + [cmap[1]]),
                        boundaries=[0, 1, 2, 3, 4])
    cbar.set_ticks(np.linspace(0.5, 3.5, 4))
    cbar.set_ticklabels(['p < 0.001', 'p < 0.01', 'p < 0.05', 'NS'])

    cbar.outline.set_linewidth(1)
    cbar.outline.set_edgecolor('0.5')
    cbar.ax.tick_params(size=0)

    return g, cbar
示例#18
0
class MplCanvas(MyMplCanvas):#,gui.QWidget):#(MyMplCanvas):
    """
    A class for displaying radar data in basic mode. In this mode, the width and height of plot are equal.

    Parameters 
    ----------
    title : string
        Plotting header label.
    colormap : ColorMap
        ColorMap object.

    Attributes
    ----------
    figurecanvas : FigureCanvas
        The canvas for display.
    zoomer : list
        Storing zoom windows.
    _zoomWindow : QRectF
        Storing current zoom window.
    origin : list
        Storing the coordinates for onPress event.
    var_ : dict
        Storing variables for display.
    AZIMUTH : boolean
        Flag for azimuth display.
    RANGE_RING : boolean
        Flag for RANGE_RING display.
    COLORBAR : boolean
        Flag for colorbar display.
    PICKER_LABEL : boolean
        Flag for picker label display.
    cb : ColorbarBase
        Colorbar object.
    cMap : ColorMap
        ColorMap object.
    pressEvent : event
        Press event.
    pressed : boolean
        Flag for press event.
    deltaX : float
        X change of rubberband. Zoom window only when the change is greater than ZOOM_WINDOW_PIXEL_LIMIT.
    deltaY : float
        Y change of rubberband.
    startX : float
        Rubberband start x value.
    startY : float
        Rubberband start y value.
    moveLabel : QLabel
        Picker label
    sweep : Sweep 
        Sweep object.
    ranges : list
        Sweep ranges
    varName : string
        Storing current display variable name.
    x : list
        Storing sweep x values.
    y : list
        Storing sweep y values.
    label : string
        Storing header label and sweep time stamp
    """

    def __init__(self, title, colormap, parent=None, width=3, height=3, dpi=100):
        """Create the figure, the square plotting axes and the widget state.

        Parameters
        ----------
        title : string
            Header label; drawSweep appends the sweep time label to it.
        colormap : ColorMap
            Called with a variable name to obtain a colormap; its
            ``ticks_label`` mapping supplies the colorbar boundaries.
        parent : QWidget, optional
            Qt parent widget.
        width, height : int, optional
            Unused; the figure is sized from the widget in resizeEvent.
        dpi : int, optional
            Figure resolution.
        """
        self.fig = Figure()
        # Do not call plt.axis('off') here: a bare Figure is not managed by
        # pyplot, so that call only created/modified an unrelated global
        # pyplot figure.
        self.axes = self.fig.add_subplot(111, aspect='equal')
        self.fig.set_dpi(dpi)
        self.headerLabel = title

        FigureCanvas.__init__(self, self.fig)
        # FigureCanvas.__init__ returns None; this widget itself is the canvas.
        self.figurecanvas = self
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self,
                                   gui.QSizePolicy.Expanding,
                                   gui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Full (unzoomed) window in data coordinates, centred on the origin;
        # the negative height makes bottom() = -RENDER_PIXELS/2 < top().
        self.setWindow(core.QRectF(-1. * RENDER_PIXELS/2., 1. * RENDER_PIXELS/2., 1. * RENDER_PIXELS, -1. * RENDER_PIXELS))
        self.ignorePaint = False
        self.rubberBand = gui.QRubberBand(gui.QRubberBand.Rectangle, self)
        self.zoomer = []  # stack of previous zoom windows, for undo

        self.origin = [RENDER_PIXELS,RENDER_PIXELS]
        self.scaleFactor = 1.0
        self.var_ = {}
        # display toggles
        self.AZIMUTH = False
        self.RANGE_RING = False
        self.COLORBAR = True
        self.PICKER_LABEL = False
        self.cb = None
        self.cMap = colormap

        # rubber-band drag state, maintained by onPress/onMove/onRelease
        self.pressEvent = None
        self.pressed = False
        self.deltaX = 0.
        self.deltaY = 0.
        self.startX = None
        self.startY = None

        # floating label used by the picker (PICKER_LABEL)
        self.moveLabel = gui.QLabel("",self)
        self.moveLabel.setText("")
        self.moveLabel.hide()
        self.moveLabel.setStyleSheet("font-size:12px; margin:3px; padding:4px; background:#FFFFFF; border:2px solid #000;")

        self.mpl_connect('button_press_event', self.onPress)
        self.mpl_connect('button_release_event', self.onRelease)
        self.mpl_connect('motion_notify_event', self.onMove)

    def onPress(self, event):
        """Start a rubber-band drag when the left button is pressed in the axes.

        Records the press event, the rubber-band origin in widget pixels
        (vertically flipped with ``self.height() - event.y``) and the data
        coordinates of the start point. Presses outside the axes, and the
        middle/right buttons, are ignored.
        """
        if event.button != 1:  # only the left button starts a zoom drag
            return
        xdata, ydata = event.xdata, event.ydata
        if xdata is None or ydata is None:
            return  # press happened outside the axes

        self.pressed = True
        self.pressEvent = event

        self.origin = core.QPoint(event.x, self.height() - event.y)
        self.rubberBand.setGeometry(core.QRect(self.origin, core.QSize()))
        self.rubberBand.show()

        # data-space start point, used by onRelease
        self.startX, self.startY = xdata, ydata

    def onMove(self, event):
        """Track mouse motion.

        While dragging, keep the rubber band square (side = the smaller of
        the horizontal and vertical pixel moves, signed per axis) and store
        the signed extents in ``deltaX``/``deltaY`` for onRelease. Otherwise,
        when PICKER_LABEL is on, show a label next to the cursor with the
        data coordinates and the nearest sweep value.
        """
        xdata, ydata = event.xdata, event.ydata
        if xdata is None or ydata is None:
            # cursor is outside the axes
            self.moveLabel.hide()
            return

        if not self.pressed:
            if not self.PICKER_LABEL:
                return
            i, j = self.retrieve_z_value(xdata, ydata)
            self.moveLabel.show()
            if i is None or j is None:
                text = r"x=%g, y=%g, z=n/a" % (xdata,ydata)
            else:
                # TODO: should use xdata or self.x[i][j]
                text = r"x=%g, y=%g, z=%g" % (xdata,ydata,self.var_[i][j])
            self.moveLabel.setText(text)
            self.moveLabel.adjustSize()
            label_w = self.moveLabel.width()
            # put the label on the left of the cursor near the right edge
            offset = -10 - label_w if self.width() - event.x < label_w else 10
            self.moveLabel.move(event.x + offset, self.height() - event.y)
            return

        # dragging: resize the rubber band
        if self.PICKER_LABEL:
            self.moveLabel.hide()
        moved_x = event.x - self.pressEvent.x
        moved_y = event.y - self.pressEvent.y
        side = min(fabs(moved_x), fabs(moved_y))
        dx = -side if moved_x < 0 else side
        dy = -side if moved_y < 0 else side
        band = core.QRect(self.origin.x(), self.origin.y(), int(dx), -int(dy)).normalized()
        self.rubberBand.setGeometry(band)
        self.deltaX = dx
        self.deltaY = dy

    def retrieve_z_value(self, xdata, ydata):
        """Return the (row, column) indices of the sample nearest (xdata, ydata).

        For every row of ``self.x``, findNearest picks the column whose x is
        closest to ``xdata``. Among those candidates, the one at the smallest
        Euclidean distance wins; distances of 99999 or more are ignored.
        Returns (None, None) when there is no candidate.
        """
        best_dist = 99999
        best_i = best_j = None
        for i, row_x in enumerate(self.x):
            j = self.findNearest(np.copy(row_x), xdata)
            if j is None:
                continue
            dist = self.distance(xdata, ydata, row_x[j], self.y[i][j])
            if dist < best_dist:
                best_dist, best_i, best_j = dist, i, j
        return best_i, best_j

    def onRelease(self, event):
        """Finish a left-button drag by zooming to the square the rubber band swept.

        The zoom rectangle starts at the data point recorded by onPress. It
        extends by the signed pixel extents stored by onMove, converted to
        data units with the current axis ranges. The zoom happens only when:

        * the drag exceeded ZOOM_WINDOW_PIXEL_LIMIT pixels;
        * the new window is wider than ZOOM_WINDOW_WIDTH_LIMIT;
        * the new window is at least 1% of the current width.

        The current window is then pushed onto ``self.zoomer`` for undo.
        """
        if event.button != 1:
            return
        was_pressed = self.pressed
        self.pressed = False
        self.rubberBand.hide()

        # Consume the drag extents. Only onMove refreshes them, so without
        # this reset a plain click after an earlier drag would zoom using
        # that drag's stale size.
        deltaX, deltaY = self.deltaX, self.deltaY
        self.deltaX = self.deltaY = 0.

        xdata = event.xdata  # mouse position in data coordinates
        ydata = event.ydata
        # Ignore releases that do not end a drag started inside the axes;
        # startX/startY would otherwise be left over from an earlier press.
        if (not was_pressed or xdata is None or ydata is None
                or self.startX is None or self.startY is None):
            return

        # pixel span corresponding to the full axis range
        # (presumably the axes width in pixels - TODO confirm)
        d0 = self.width() * FIGURE_CANCAS_RATIO
        x_lo, x_hi = self.axes.get_xlim()
        y_lo, y_hi = self.axes.get_ylim()
        x_range = x_hi - x_lo
        y_range = y_hi - y_lo
        (x1, y1) = self.startX, self.startY
        (x2, y2) = x1 + deltaX/d0 * x_range, y1 + deltaY/d0 * y_range

        oldRect = core.QRectF()  # current view, saved for undo
        oldRect.setLeft(x_lo)
        oldRect.setRight(x_hi)
        oldRect.setBottom(y_lo)
        oldRect.setTop(y_hi)

        rect = core.QRectF()  # requested zoom window
        rect.setLeft(min(x1, x2))
        rect.setRight(max(x1, x2))
        rect.setBottom(min(y1, y2))
        rect.setTop(max(y1, y2))

        # react only when the dragged region is big enough, and at least
        # 0.01 times the width of the current view
        if fabs(deltaX) > ZOOM_WINDOW_PIXEL_LIMIT and \
           fabs(rect.width()) > ZOOM_WINDOW_WIDTH_LIMIT and \
           fabs(rect.width()) >= 0.01*fabs(oldRect.width()):
            self.zoomer.append(oldRect)
            self.zoomTo(rect)
            self._zoomWindow = rect

    def zoomTo(self, rect):
        """Use ``rect`` (data coordinates) as the axes limits and redraw."""
        axes = self.axes
        axes.set_xlim(rect.left(), rect.right())
        axes.set_ylim(rect.bottom(), rect.top())
        self.draw()

    def findNearest(self, array, target, threshold=0.151):
        """Return the index of the element of ``array`` closest to ``target``.

        Only elements within ``threshold`` of ``target`` are candidates. If
        none qualifies, None is returned. The old comment claimed None meant
        "target is greater than any value"; that was inaccurate.

        The default 0.151 looks tied to the ~150 m gate spacing
        (meters_between_gates = 150.000005960464), presumably in km.
        TODO: confirm the units.
        """
        diff = abs(array - target)
        too_far = np.ma.greater(diff, threshold)
        if np.all(too_far):
            return None
        return np.ma.masked_array(diff, too_far).argmin()
    
    def distance(self, x1, y1, x2, y2):
        """Return the Euclidean distance between (x1, y1) and (x2, y2)."""
        dx = x1 - x2
        dy = y1 - y2
        return sqrt(dx**2 + dy**2)

    def sizeHint(self):
        """Preferred widget size: the canvas's current width and height."""
        return core.QSize(*self.get_width_height())

    def minimumSizeHint(self):
        """Smallest size (10x10 pixels) the layout may shrink this widget to."""
        min_side = 10
        return core.QSize(min_side, min_side)

    def setWindow(self, window):
        """Set the full window (data coordinates) to use for this widget.

        ``window`` becomes the current zoom window, and its width/height
        ratio is cached for getAspectRatio().
        """
        self._zoomWindow = window
        # float() avoids integer truncation under Python 2 when an integer
        # QRect is passed.
        self._aspectRatio = float(window.width()) / window.height()

    def resizeEvent(self, event):
        """Resize the figure to the widget's new pixel size and redraw.

        The new size is converted to inches with the figure dpi, and
        ``self.origin`` is set to [width, height] in pixels.
        """
        new_size = event.size()
        w_px, h_px = new_size.width(), new_size.height()
        dpi = self.fig.dpi
        self.fig.set_size_inches(float(w_px)/dpi, float(h_px)/dpi)
        self.fig.canvas.draw()
        self.origin = [w_px, h_px]
        
    def drawSweep(self, sweep, varName, beamWidth):
        """Display variable ``varName`` of ``sweep`` and redraw.

        Stores the sweep, its geometry (ranges, x, y), the selected values
        and a title (header label + the sweep's time label), then calls
        update_figure(). The stored ``varName`` is lower-cased (it is used
        for the colormap lookup), but the values are read with the name as
        given.
        """
        self.beamWidth = beamWidth
        self.ranges = sweep.ranges
        self.sweep = sweep
        self.varName = varName.lower()
        selected = sweep.vars_[varName]  # comment upstream says "in list"
        self.var_ = selected
        self.x = sweep.x
        self.y = sweep.y
        self.label = self.headerLabel + sweep.timeLabel
        self.update_figure()

    def update_figure(self):
        """Redraw the stored sweep. Callers must invoke it explicitly.

        Does nothing when no data is loaded. Otherwise it:

        * masks values below -32000 (the -32768 missing-data sentinel);
        * takes the color limits from the remaining data;
        * draws the optional range rings, azimuth lines and colorbar;
        * reapplies the current zoom window and title.
        """
        if len(self.var_) == 0:
            return
        self.axes.clear()
        # Mask missing values (-32768). masked_less also handles the plain
        # list drawSweep may store (comparing a list with an int does not
        # give an element-wise mask) and keeps any mask already present.
        self.var_ = np.ma.masked_less(self.var_, -32000)
        # Masked-aware reductions: the builtin min/max over masked rows can
        # yield the ``masked`` constant when a row starts with a gap.
        vmin = self.var_.min()
        vmax = self.var_.max()

        im = self.axes.pcolormesh(self.x, self.y, self.var_, vmin=vmin,
                                  vmax=vmax, cmap=self.cMap(self.varName))
        # optional overlays
        if self.RANGE_RING:
            self.draw_range_ring()
        if self.AZIMUTH:
            self.draw_azimuth_line()
        if self.COLORBAR:
            self.draw_colorbar(im, vmin, vmax)

        # The zoom window does not change between variables: keep using the
        # current one.
        self.zoomTo(self._zoomWindow)
        self.axes.set_title(self.label, size=9)  # TODO: adaptive size
        self.fig.canvas.draw()

    def draw_azimuth_line(self):
        """Draw dashed spokes from the origin every 30 degrees, out to radius R.

        Each spoke at math angle ``a`` (counter-clockwise from +x) is labelled
        ``(90 - a) % 360``, i.e. the angle measured clockwise from +y. Labels
        are nudged by a few points depending on the quadrant so they do not
        sit on the line.
        """
        def _offset(ang):
            # (x, y) text offset in points for a spoke at angle ``ang``
            if 90 < ang < 180:
                return -10, 3
            if ang == 180:
                return -15, -3
            if 180 < ang < 270:
                return -12, -10
            if ang == 270:
                return -10, -8
            if 270 < ang < 360:
                return 0, -5
            return 0, 0

        angles = np.arange(0, 360, 30)
        labels = (90 - angles) % 360  # 90, 60, 30, 0, 330, ..., 120
        xs = R * np.cos(np.pi*angles/180)
        ys = R * np.sin(np.pi*angles/180)

        for xi, yi, ang, lb in zip(xs, ys, angles, labels):
            spoke = plt.Line2D([0, xi], [0, yi], linestyle='dashed',
                               color='lightgray', lw=0.8)
            self.axes.add_line(spoke)
            self.axes.annotate(str(lb), xy=(xi, yi), xycoords='data',
                               xytext=_offset(ang), textcoords='offset points',
                               arrowprops=None, size=10)

    def draw_range_ring(self):
        """Draw dashed range rings every 30 data units from 0 up to R.

        Each ring is labelled with its integer radius at the point where it
        crosses the 135-degree ray.
        """
        # Label position on the 135-degree ray; hoisted out of the loop.
        # The leftover Python 2 debug ``print`` has been removed.
        angle = np.pi*135./180.
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        for r in np.arange(0, R+1, 30):
            circ = plt.Circle((0, 0), radius=r, linestyle='dashed',
                              color='lightgray', lw=0.8, fill=False)
            self.axes.add_patch(circ)
            self.axes.annotate(str(int(r)), xy=(r * cos_a, r * sin_a),
                               xycoords='data', arrowprops=None, size=10)

    def draw_colorbar(self, im, vmin, vmax):
        """Draw a horizontal, discrete colorbar just below the plot.

        Boundaries and tick labels come from ``self.cMap.ticks_label``. The
        lookup uses the current variable name, or, if that is missing, the
        part of the name before its first '_' ('vel_f' falls back to 'vel').
        Any previously drawn colorbar is removed first. ``im``, ``vmin`` and
        ``vmax`` are not used.

        Raises
        ------
        RuntimeError
            If no color scale is configured for the variable. A message box
            is shown to the user first.
        """
        ticks_label = self.cMap.ticks_label
        # Resolve the scale before touching the figure, so a missing scale
        # leaves the existing colorbar in place instead of an empty axes.
        substName = self.varName
        if substName not in ticks_label:
            # str.find returns -1 when there is no '_' (and 0 for a leading
            # one): only strip a real, non-empty prefix.
            u = substName.find('_')
            if u > 0:
                substName = substName[:u]
            if substName not in ticks_label:
                message = """ Please define a color scale for '{0}' in your configuration file """.format(self.varName)
                msgBox = gui.QMessageBox()
                msgBox.setText(message)
                msgBox.exec_()
                raise RuntimeError(message)

        if self.cb is not None:
            # remove the previous colorbar's own axes
            self.fig.delaxes(self.cb.ax)
            self.fig.subplots_adjust(right=0.90)

        # position is read after subplots_adjust, which may move the axes
        l, b, w, _ = self.axes.get_position().bounds
        cax = self.fig.add_axes([l, b-0.06, w, 0.03])  # colorbar axes
        cmap = self.cMap(self.varName)
        bounds = ticks_label[substName]
        norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
        self.cb = ColorbarBase(cax, cmap=cmap, norm=norm, orientation='horizontal',
                               boundaries=bounds, ticks=bounds)
        self.cb.ax.tick_params(labelsize=8)
        self.cb.set_ticklabels([str(i) for i in bounds], update_ticks=True)
        self.cb.set_label('Color Scale', size=8)

    def resetFactors(self):
        """Clear the zoom history and restore the full window set in __init__.

        The restored window takes effect on the next update_figure(); this
        method itself only redraws the canvas.
        """
        self.zoomer = []
        # Must match __init__: an integer QRect with positive height has
        # bottom() == top + height - 1, which zoomTo() would turn into a
        # flipped, shifted y range.
        self.setWindow(core.QRectF(-1. * RENDER_PIXELS/2., 1. * RENDER_PIXELS/2., 1. * RENDER_PIXELS, -1. * RENDER_PIXELS))
        self.fig.canvas.draw()

    def changeZoomerPointer(self, ind=None):
        """Step back through the zoom history.

        With ``ind`` None, return to the previous zoom window (undo one
        zoom). Otherwise return to the first stored window and clear the
        history. Does nothing when the history is empty. The restored window
        becomes the current zoom window, so update_figure() keeps it instead
        of re-applying the undone one.
        """
        if not self.zoomer:
            return
        if ind is None:
            zoomWindow = self.zoomer[-1]
            self.zoomTo(zoomWindow)
            self.zoomer.pop()
        else:
            zoomWindow = self.zoomer[0]
            self.zoomTo(zoomWindow)
            self.zoomer = []
        self._zoomWindow = zoomWindow
            
    def getAspectRatio(self):
        """Return the window width/height ratio cached by setWindow()."""
        ratio = self._aspectRatio
        return ratio

    def keyPressEvent(self, event):
        """Handle key presses: 'C' resets the zoom window (see resetFactors)."""
        # The leftover Python 2 debug print statement has been removed.
        if event.key() == core.Qt.Key_C:
            self.resetFactors()
            event.accept()

    '''
示例#19
0
def plot_iso_similarities(plt_clf, plt_mtype, plt_mcomb, pred_df, pheno_dict,
                          cdata, args):
    """Plot how other isolated mutations score relative to ``plt_mcomb``.

    Three panels are drawn side by side:

    * left: split violins of the per-sample mean of ``pred_df`` row
      ``plt_mcomb`` over the training samples. One violin is for samples
      without ``plt_mtype`` and one for samples with it, both restricted to
      samples whose ``cdata.train_pheno(plt_mcomb.not_mtype)`` value is
      False. Dashed lines mark the two group means, labelled 0 and 1.
    * middle: one violin per mutation combination picked by the filter
      below. Each shows the per-sample means of ``pred_df`` row ``mcomb``
      over training samples in ``pheno_dict[mcomb]``. Each violin is
      coloured and annotated by its mean rescaled so that the wild-type
      mean maps to 0 and the mutant mean to 1.
    * right: a colorbar legend for that rescaled score.

    The figure is saved as
    ``<plot_dir>/<args.gene>/<args.cohort>__iso-similarities.svg`` and then
    closed. ``plt_clf`` is not used.
    """
    mtype_lbl = get_fancy_label(plt_mtype)

    sim_mcombs = {
        mcomb
        for mcomb in pred_df.index
        if (isinstance(mcomb, ExMcomb) and mcomb != plt_mcomb
            and len(mcomb.mtypes) == 1
            and not (mcomb.all_mtype & dict(cna_mtypes)['Shal']).is_empty() and
            (copy_mtype.is_supertype(tuple(mcomb.mtypes)[0]) or
             (len(tuple(mcomb.mtypes)[0].subkeys()) == 1 and not any(
                 'domain' in lvl
                 for lvl in tuple(mcomb.mtypes)[0].get_levels()))))
    }

    fig, (vio_ax, sim_ax, lgnd_ax) = plt.subplots(
        figsize=(2 + len(sim_mcombs), 6),
        nrows=1,
        ncols=3,
        gridspec_kw=dict(width_ratios=[1, len(sim_mcombs), 1]))

    vals_df = pd.DataFrame({
        'Value':
        pred_df.loc[plt_mcomb, cdata.get_train_samples()].apply(np.mean),
        'mStat':
        pheno_dict[plt_mtype],
        'dummy':
        0,
        'eStat':
        np.array(cdata.train_pheno(plt_mcomb.not_mtype))
    })

    sim_df = pd.concat([
        pd.DataFrame({
            'Mcomb':
            mcomb,
            'Value':
            pred_df.loc[mcomb,
                        cdata.get_train_samples()][pheno_dict[mcomb]].apply(
                            np.mean)
        }) for mcomb in sim_mcombs
    ])

    # wild-type / mutant samples, both without the excluded mutations
    wt_mask = ~vals_df.mStat & ~vals_df.eStat
    mut_mask = vals_df.mStat & ~vals_df.eStat

    for stat_mask, clr_key in ((wt_mask, 'WT'), (mut_mask, 'Point')):
        sns.violinplot(data=vals_df[stat_mask],
                       x='dummy',
                       y='Value',
                       hue='eStat',
                       palette=[variant_clrs[clr_key]],
                       hue_order=[False, True],
                       split=True,
                       linewidth=0,
                       cut=0,
                       ax=vio_ax)

    vio_ax.set_xlim(-0.5, 0.01)
    # the first two children are the violins drawn above
    for art in vio_ax.get_children()[:2]:
        art.set_alpha(0.41)

    vio_ax.set_yticks([])
    vio_ax.get_legend().remove()
    vio_ax.set_zorder(1)

    wt_mean = np.mean(vals_df.Value[wt_mask])
    mut_mean = np.mean(vals_df.Value[mut_mask])
    # dashed mean lines extending across the similarity panel
    for mean_val, clr_key in ((wt_mean, 'WT'), (mut_mean, 'Point')):
        vio_ax.axhline(y=mean_val,
                       xmin=0,
                       xmax=1.31 + len(sim_mcombs),
                       color=variant_clrs[clr_key],
                       clip_on=False,
                       linestyle='--',
                       linewidth=1.7,
                       alpha=0.47)

    vals_min, vals_max = pd.concat([vals_df, sim_df],
                                   sort=False).Value.quantile(q=[0, 1])
    vals_rng = (vals_max - vals_min) / 103
    plt_min = min(vals_min - vals_rng * 13, 2 * wt_mean - mut_mean)
    plt_max = max(vals_max + vals_rng, 2 * mut_mean - wt_mean)

    for mean_val, mean_lbl in ((wt_mean, "0"), (mut_mean, "1")):
        vio_ax.text(-0.52,
                    mean_val,
                    mean_lbl,
                    size=15,
                    fontweight='bold',
                    ha='right',
                    va='center')

    vio_ax.text(0.99,
                0,
                "Isolated\nClassification\n of "
                "{}\nMutations (M1)".format(mtype_lbl),
                size=13,
                fontstyle='italic',
                ha='right',
                va='top',
                transform=vio_ax.transAxes)

    # mean score of each combination, rescaled so WT mean -> 0, M1 mean -> 1
    mcomb_grps = sim_df.groupby('Mcomb')['Value']
    mcomb_scores = mcomb_grps.mean().sort_values(ascending=False) - wt_mean
    mcomb_scores /= (mut_mean - wt_mean)

    mcomb_mins = mcomb_grps.min()
    mcomb_maxs = mcomb_grps.max()
    mcomb_sizes = mcomb_grps.count()
    clr_norm = colors.Normalize(vmin=-1, vmax=2)

    sns.violinplot(data=sim_df,
                   x='Mcomb',
                   y='Value',
                   order=mcomb_scores.index,
                   palette=simil_cmap(clr_norm(mcomb_scores.values)),
                   saturation=1,
                   linewidth=10 / 7,
                   cut=0,
                   width=16 / 17,
                   ax=sim_ax)

    # Series.iteritems was removed in pandas 2.0; items() works everywhere
    for i, (mcomb, scr) in enumerate(mcomb_scores.items()):
        sim_ax.get_children()[i * 2].set_alpha(8 / 11)
        mcomb_lbl = get_fancy_label(tuple(mcomb.mtypes)[0], phrase_link='\n')

        sim_ax.text(i,
                    mcomb_mins[mcomb] - vals_rng,
                    "{}\n({} samples)".format(mcomb_lbl, mcomb_sizes[mcomb]),
                    size=7,
                    ha='center',
                    va='top')
        sim_ax.text(i,
                    mcomb_maxs[mcomb] + vals_rng / 2,
                    format(scr, '.2f'),
                    size=12,
                    fontweight='bold',
                    ha='center',
                    va='bottom')

    sim_ax.text(0.5,
                0,
                "<{}> Classifier\nScoring of Other "
                "Isolated\n{} Mutations (M2)".format(mtype_lbl, args.gene),
                size=13,
                fontstyle='italic',
                ha='center',
                va='top',
                transform=sim_ax.transAxes)

    # place the legend colorbar so that score -1..2 spans the matching
    # value range of the shared y axis
    clr_min = 2 * wt_mean - mut_mean
    clr_max = 2 * mut_mean - wt_mean
    clr_btm = (clr_min - plt_min) / (plt_max - plt_min)
    clr_top = (clr_max - plt_min) / (plt_max - plt_min)
    clr_rng = (clr_max - clr_min) * 1.38 / (plt_max - plt_min)
    clr_btm = clr_btm - (clr_top - clr_btm) * 0.19

    clr_ax = lgnd_ax.inset_axes(bounds=(0, clr_btm, 0.53, clr_rng))
    clr_bar = ColorbarBase(ax=clr_ax,
                           cmap=simil_cmap,
                           norm=clr_norm,
                           extend='both',
                           extendfrac=0.19,
                           ticks=[-0.5, 0, 0.5, 1.0, 1.5])

    clr_bar.set_ticklabels(
        ['M2 < WT', 'M2 = WT', 'WT < M2 < M1', 'M2 = M1', 'M2 > M1'])
    clr_ax.tick_params(labelsize=12)

    for ax in vio_ax, sim_ax, lgnd_ax:
        ax.set_ylim(plt_min, plt_max)
        ax.axis('off')

    # act on this figure explicitly rather than pyplot's "current" figure
    fig.tight_layout(pad=0, w_pad=1.1)
    fig.savefig(os.path.join(plot_dir, args.gene,
                             "{}__iso-similarities.svg".format(args.cohort)),
                bbox_inches='tight',
                format='svg')

    plt.close(fig)
示例#20
0
class BaseSlicer(object):
    """ The main purpose of these class is to have auto adjust of axes size
        to the data with different layout of cuts.
    """
    # This actually encodes the figsize for only one axe
    _default_figsize = [2.2, 2.6]
    _axes_class = CutAxes

    def __init__(self, cut_coords, axes=None, black_bg=False, **kwargs):
        """ Create 3 linked axes for plotting orthogonal cuts.

            Parameters
            ----------
            cut_coords: 3 tuple of ints
                The cut position, in world space.
            axes: matplotlib axes object, optional
                The frame axes to subdivide into the cut views. When None,
                a full-figure axes with its axis turned off is created.
            black_bg: boolean, optional
                If True, the background of the figure will be put to
                black. To save figures with a black background you also
                need to pass "facecolor='k', edgecolor='k'" to pylab's
                savefig.
            kwargs:
                Passed on to _init_axes.
        """
        self.cut_coords = cut_coords
        if axes is None:
            axes = plt.axes((0., 0., 1., 1.))
            axes.axis('off')
        self.frame_axes = axes
        axes.set_zorder(1)
        bbox = axes.get_position()
        self.rect = (bbox.x0, bbox.y0, bbox.x1, bbox.y1)
        self._black_bg = black_bg
        self._colorbar = False
        # colorbar geometry, as fractions of the frame axes size
        frame_w, frame_h = bbox.width, bbox.height
        self._colorbar_width = 0.05 * frame_w
        self._colorbar_margin = {'left': 0.25 * frame_w,
                                 'right': 0.02 * frame_w,
                                 'top': 0.05 * frame_h,
                                 'bottom': 0.05 * frame_h}
        self._init_axes(**kwargs)

    @staticmethod
    def find_cut_coords(img=None, threshold=None, cut_coords=None):
        """Return the cut coordinates to display for ``img``.

        Abstract here: subclasses implement it as a staticmethod or a
        classmethod.
        """
        raise NotImplementedError

    @classmethod
    def init_with_figure(cls, img, threshold=None,
                         cut_coords=None, figure=None, axes=None,
                         black_bg=False, leave_space=False, colorbar=False,
                         **kwargs):
        """ Create a slicer, building its figure and frame axes if needed.

            Parameters
            ----------
            img: Niimg-like object, None or False
                Image used by find_cut_coords. Unless it is None or False,
                it is checked to be 3D (this handles "fake" 4D images).
            threshold: optional
                Passed to find_cut_coords.
            cut_coords: optional
                Passed to find_cut_coords.
            figure: matplotlib figure or plt.figure's first argument, optional
                If not a Figure (e.g. None or a figure number), a figure is
                created whose width scales with the number of cuts.
            axes: matplotlib axes or [left, bottom, width, height], optional
                The frame axes. A sequence is turned into axes with
                figure.add_axes. None means the whole figure, or its right
                70% when leave_space is True.
            black_bg: boolean, optional
                Black background for a newly created figure; also passed to
                the slicer.
            leave_space: boolean, optional
                Leave free space on the left of the figure.
            colorbar: boolean, optional
                Widen a newly created figure to make room for a colorbar.
            kwargs:
                Passed to the slicer constructor.

            Returns
            -------
            A new instance of ``cls``.
        """
        # collections.Sequence was removed in Python 3.10; the ABC lives in
        # collections.abc (Python 3.3+).
        try:
            from collections.abc import Sequence
        except ImportError:  # Python 2
            from collections import Sequence

        # deal with "fake" 4D images
        if img is not None and img is not False:
            img = _utils.check_niimg_3d(img)

        cut_coords = cls.find_cut_coords(img, threshold, cut_coords)

        if isinstance(axes, plt.Axes) and figure is None:
            figure = axes.figure

        if not isinstance(figure, plt.Figure):
            # Make sure that we have a figure
            figsize = cls._default_figsize[:]

            # Adjust for the number of axes
            figsize[0] *= len(cut_coords)

            # Make space for the colorbar
            if colorbar:
                figsize[0] += .7

            facecolor = 'k' if black_bg else 'w'

            if leave_space:
                figsize[0] += 3.4
            figure = plt.figure(figure, figsize=figsize,
                                facecolor=facecolor)
        if isinstance(axes, plt.Axes):
            assert axes.figure is figure, ("The axes passed are not "
                                           "in the figure")

        if axes is None:
            axes = [0.3, 0, .7, 1.] if leave_space else [0., 0., 1., 1.]
        if isinstance(axes, Sequence):
            axes = figure.add_axes(axes)
        # People forget to turn their axis off, or to set the zorder, and
        # then they cannot see their slicer
        axes.axis('off')
        return cls(cut_coords, axes, black_bg, **kwargs)


    def title(self, text, x=0.01, y=0.99, size=15, color=None, bgcolor=None,
              alpha=1, **kwargs):
        """ Write a title to the view.

            The text is drawn on the first displayed cut's axes, positioned
            in frame-axes coordinates inside a square box.

            Parameters
            ----------
            text: string
                The title text.
            x: float, optional
                Horizontal position, as a fraction of the frame width.
            y: float, optional
                Vertical position, as a fraction of the frame height.
            size: integer, optional
                Font size of the title.
            color: matplotlib color specifier, optional
                Font color. Defaults to black on a black-background
                figure, white otherwise.
            bgcolor: matplotlib color specifier, optional
                Box color. Defaults to white on a black-background
                figure, black otherwise.
            alpha: float, optional
                Alpha value of the box.
            kwargs:
                Extra keyword arguments passed to matplotlib's text
                function.
        """
        if color is None:
            color = 'k' if self._black_bg else 'w'
        if bgcolor is None:
            bgcolor = 'w' if self._black_bg else 'k'
        first_axe = (self._cut_displayed[0]
                     if hasattr(self, '_cut_displayed')
                     else self.cut_coords[0])
        ax = self.axes[first_axe].ax
        box_style = dict(boxstyle="square,pad=.3",
                         ec=bgcolor, fc=bgcolor, alpha=alpha)
        ax.text(x, y, text,
                transform=self.frame_axes.transAxes,
                horizontalalignment='left',
                verticalalignment='top',
                size=size, color=color,
                bbox=box_style,
                zorder=1000,
                **kwargs)
        # keep the titled axes above the other cuts
        ax.set_zorder(1000)


    def add_overlay(self, img, threshold=1e-6, colorbar=False, **kwargs):
        """ Plot a 3D map in all the views.

            Parameters
            -----------
            img: Niimg-like object
                See http://nilearn.github.io/building_blocks/manipulating_mr_images.html#niimg.
                If it is a masked array, only the non-masked part will be
                plotted.
            threshold : a number, None
                If None is given, the maps are not thresholded.
                If a number is given, it is used to threshold the maps:
                values below the threshold (in absolute value) are
                plotted as transparent. A threshold of 0 masks exact
                zeros only.
            colorbar: boolean, optional
                If True, display a colorbar on the right of the plots.
                Only one overlay per figure may have a colorbar.
            kwargs:
                Extra keyword arguments are passed to imshow
                ('interpolation' defaults to 'nearest').

            Raises
            ------
            ValueError
                If a colorbar is requested while one is already shown.
        """
        if colorbar and self._colorbar:
            raise ValueError("This figure already has an overlay with a "
                             "colorbar.")
        if colorbar:
            # Only ever set the flag: an overlay without a colorbar must not
            # clear it, or a second colorbar could be added afterwards.
            self._colorbar = colorbar

        img = _utils.check_niimg_3d(img)

        if threshold is not None:
            data = img.get_data()
            if threshold == 0:
                data = np.ma.masked_equal(data, 0, copy=False)
            else:
                data = np.ma.masked_inside(data, -threshold, threshold,
                                           copy=False)
            img = new_img_like(img, data, img.get_affine())

        # Make sure that add_overlay shows consistent default behavior
        # with plot_stat_map
        kwargs.setdefault('interpolation', 'nearest')
        ims = self._map_show(img, type='imshow', **kwargs)

        if colorbar:
            self._colorbar_show(ims[0], threshold)

        plt.draw_if_interactive()

    def add_contours(self, img, **kwargs):
        """ Draw contours of a 3D map in all the views.

            Parameters
            -----------
            img: Niimg-like object
                See http://nilearn.github.io/building_blocks/manipulating_mr_images.html#niimg.
                The image whose contours are drawn.
            kwargs:
                Forwarded to matplotlib's contour (see pylab.contour).
                Typical ones are "levels", the list of values at which to
                draw contours, and "colors", one color or a list of colors
                for those contours.
        """
        self._map_show(img, type='contour', **kwargs)
        # refresh only when pyplot is in interactive mode
        plt.draw_if_interactive()

    def _map_show(self, img, type='imshow', resampling_interpolation='continuous', **kwargs):
        """Draw ``img`` on every display axes with plotting ``type``.

        The image is first passed through reorder_img. When its data is a
        masked array, the bounding box handed to draw_2d is the bounds of
        the unmasked part; otherwise it is the full data bounds. Unless
        given, ``vmin``/``vmax`` default to the extrema over all cuts that
        fall inside the data. Returns the objects returned by draw_2d, one
        per drawn cut.
        """
        img = reorder_img(img, resample=resampling_interpolation)

        affine = img.get_affine()
        data = img.get_data()
        data_bounds = get_bounds(data.shape, affine)
        (xmin, xmax), (ymin, ymax), (zmin, zmax) = data_bounds

        box_limits = (xmin, xmax, ymin, ymax, zmin, zmax)
        if hasattr(data, 'mask') and isinstance(data.mask, np.ndarray):
            not_mask = np.logical_not(data.mask)
            box_limits = get_mask_bounds(new_img_like(img, not_mask, affine))
        xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = box_limits

        display_axes = list(self.axes.values())
        slices = []
        for display_ax in display_axes:
            try:
                cut = display_ax.transform_to_2d(data, affine)
            except IndexError:
                # the cut lies outside the indices of the data
                cut = None
            slices.append(cut)

        present = [cut for cut in slices if cut is not None]
        if 'vmin' not in kwargs:
            kwargs['vmin'] = min(cut.min() for cut in present)
        if 'vmax' not in kwargs:
            kwargs['vmax'] = max(cut.max() for cut in present)

        bounding_box = (xmin_, xmax_), (ymin_, ymax_), (zmin_, zmax_)

        return [display_ax.draw_2d(cut, data_bounds, bounding_box,
                                   type=type, **kwargs)
                for display_ax, cut in zip(display_axes, slices)
                if cut is not None]

    def _colorbar_show(self, im, threshold):
        if threshold is None:
            offset = 0
        else:
            offset = threshold
        if offset > im.norm.vmax:
            offset = im.norm.vmax

        # create new  axis for the colorbar
        figure = self.frame_axes.figure
        _, y0, x1, y1 = self.rect
        height = y1 - y0
        x_adjusted_width = self._colorbar_width / len(self.axes)
        x_adjusted_margin = self._colorbar_margin['right'] / len(self.axes)
        lt_wid_top_ht = [x1 - (x_adjusted_width + x_adjusted_margin),
                         y0 + self._colorbar_margin['top'],
                         x_adjusted_width,
                         height - (self._colorbar_margin['top'] +
                                   self._colorbar_margin['bottom'])]
        self._colorbar_ax = figure.add_axes(lt_wid_top_ht, axis_bgcolor='w')

        our_cmap = im.cmap
        # edge case where the data has a single value
        # yields a cryptic matplotlib error message
        # when trying to plot the color bar
        nb_ticks = 5 if im.norm.vmin != im.norm.vmax else 1
        ticks = np.linspace(im.norm.vmin, im.norm.vmax, nb_ticks)
        bounds = np.linspace(im.norm.vmin, im.norm.vmax, our_cmap.N)

        # some colormap hacking
        cmaplist = [our_cmap(i) for i in range(our_cmap.N)]
        istart = int(im.norm(-offset, clip=True) * (our_cmap.N - 1))
        istop = int(im.norm(offset, clip=True) * (our_cmap.N - 1))
        for i in range(istart, istop):
            cmaplist[i] = (0.5, 0.5, 0.5, 1.)  # just an average gray color
        if im.norm.vmin == im.norm.vmax:  # len(np.unique(data)) == 1 ?
            return
        else:
            our_cmap = our_cmap.from_list('Custom cmap', cmaplist, our_cmap.N)

        self._cbar = ColorbarBase(
            self._colorbar_ax, ticks=ticks, norm=im.norm,
            orientation='vertical', cmap=our_cmap, boundaries=bounds,
            spacing='proportional')
        self._cbar.set_ticklabels(["%.2g" % t for t in ticks])

        self._colorbar_ax.yaxis.tick_left()
        tick_color = 'w' if self._black_bg else 'k'
        for tick in self._colorbar_ax.yaxis.get_ticklabels():
            tick.set_color(tick_color)
        self._colorbar_ax.yaxis.set_tick_params(width=0)

        self._cbar.update_ticks()

    def add_edges(self, img, color='r'):
        """ Plot the edges of a 3D map in all the views.

            Parameters
            -----------
            img: Niimg-like object
                The 3D image whose edges are drawn; it is passed through
                reorder_img before being cut.
            color: matplotlib color: string or (r, g, b) value
                The color used to display the edge map
        """
        img = reorder_img(img)
        data = img.get_data()
        affine = img.get_affine()
        # Single-entry colormap: every drawn edge value maps to ``color``
        single_color_cmap = colors.ListedColormap([color])
        data_bounds = get_bounds(data.shape, affine)

        # For each ax, cut the data and plot it
        for display_ax in self.axes.values():
            try:
                data_2d = display_ax.transform_to_2d(data, affine)
                edge_mask = _edge_map(data_2d)
            except IndexError:
                # We are cutting outside the indices of the data
                continue
            # data_bounds is used both as the extent and the bounding box
            display_ax.draw_2d(edge_mask, data_bounds, data_bounds,
                               type='imshow', cmap=single_color_cmap)

        plt.draw_if_interactive()

    def annotate(self, left_right=True, positions=True, size=12, **kwargs):
        """ Add annotations to every view of the plot.

            Parameters
            ----------
            left_right: boolean, optional
                When True, each view labels which side is left and which
                side is right.
            positions: boolean, optional
                When True, each view labels the position of its cut.
            size: integer, optional
                Font size of the annotation text.
            kwargs:
                Extra keyword arguments handed to matplotlib's text
                function. 'color' defaults to white on a black background
                and black otherwise.
        """
        text_options = dict(kwargs)
        text_options.setdefault('color', 'w' if self._black_bg else 'k')
        bg_color = 'k' if self._black_bg else 'w'

        # All left/right labels are drawn first, then all cut positions.
        requested = ((left_right, 'draw_left_right'),
                     (positions, 'draw_position'))
        for enabled, drawer_name in requested:
            if not enabled:
                continue
            for display_ax in self.axes.values():
                drawer = getattr(display_ax, drawer_name)
                drawer(size=size, bg_color=bg_color, **text_options)

    def close(self):
        """ Close the matplotlib figure holding this slicer.

            Calling this is necessary to avoid leaking memory.
        """
        figure_number = self.frame_axes.figure.number
        plt.close(figure_number)

    def savefig(self, filename, dpi=None):
        """ Save the figure to a file.

            Parameters
            ----------
            filename: string
                The file name to save to. Its extension determines the
                file type, typically '.png', '.svg' or '.pdf'.
            dpi: None or scalar
                The resolution in dots per inch.
        """
        # Face and edge share the slicer background: 'k' for black_bg,
        # 'w' otherwise.
        background = 'k' if self._black_bg else 'w'
        figure = self.frame_axes.figure
        figure.savefig(filename, dpi=dpi,
                       facecolor=background, edgecolor=background)
示例#21
0
class MplCanvas(MyMplCanvas):
    """
    A class for displaying radar data in basic mode. In this mode, the width and height of plot are equal.

    Parameters
    ----------
    title : string
        Plotting header label.
    colormap : ColorMap
        ColorMap object. It is called with a variable name to obtain the
        matplotlib colormap, and its ``ticks_label`` mapping provides the
        colorbar boundaries for each variable.
    parent : QWidget, optional
        Parent widget.
    width, height : int, optional
        Currently unused.
    dpi : int, optional
        Resolution of the figure in dots per inch.

    Attributes
    ----------
    fig : Figure
        The matplotlib figure drawn on this canvas.
    axes : Axes
        The equal-aspect axes holding the sweep plot.
    figurecanvas : FigureCanvas
        The canvas for display (this widget itself).
    zoomer : list
        Stack of previous zoom windows (QRectF), most recent last.
    _zoomWindow : QRectF
        Storing current zoom window.
    origin : QPoint or list
        Rubberband anchor in widget pixels (a QPoint once onPress ran).
    var_ : dict or list
        Values of the displayed variable, indexed as var_[i][j]; an empty
        dict before the first sweep is drawn.
    AZIMUTH : boolean
        Flag for azimuth display.
    RANGE_RING : boolean
        Flag for RANGE_RING display.
    COLORBAR : boolean
        Flag for colorbar display.
    PICKER_LABEL : boolean
        Flag for picker label display.
    cb : ColorbarBase
        Colorbar object.
    cMap : ColorMap
        ColorMap object.
    pressEvent : event
        Press event.
    pressed : boolean
        Flag for press event.
    deltaX : float
        X change of rubberband. Zoom window only when the change is greater than ZOOM_WINDOW_PIXEL_LIMIT.
    deltaY : float
        Y change of rubberband.
    startX : float
        Rubberband start x value.
    startY : float
        Rubberband start y value.
    moveLabel : QLabel
        Picker label
    sweep : Sweep
        Sweep object.
    ranges : list
        Sweep ranges
    varName : string
        Storing current display variable name (lower case).
    x : list
        Storing sweep x values.
    y : list
        Storing sweep y values.
    label : string
        Storing header label and sweep time stamp
    """
    def __init__(self,
                 title,
                 colormap,
                 parent=None,
                 width=3,
                 height=3,
                 dpi=100):
        self.fig = Figure()
        # NOTE: a former ``plt.axis('off')`` call here acted on pyplot's
        # global current figure (creating one if needed), never on
        # self.fig, so it has been dropped.
        self.axes = self.fig.add_subplot(111, aspect='equal')
        self.fig.set_dpi(dpi)
        self.headerLabel = title

        # FigureCanvas.__init__ returns None; the canvas is this widget.
        FigureCanvas.__init__(self, self.fig)
        self.figurecanvas = self
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self, gui.QSizePolicy.Expanding,
                                   gui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        self.setWindow(self._defaultWindow())
        self.ignorePaint = False
        self.rubberBand = gui.QRubberBand(gui.QRubberBand.Rectangle, self)
        self.zoomer = []

        self.origin = [RENDER_PIXELS, RENDER_PIXELS]
        self.scaleFactor = 1.0
        self.var_ = {}
        self.AZIMUTH = False
        self.RANGE_RING = False
        self.COLORBAR = True
        self.PICKER_LABEL = False
        self.cb = None
        self.cMap = colormap

        self.pressEvent = None
        self.pressed = False
        self.deltaX = 0.
        self.deltaY = 0.
        self.startX = None
        self.startY = None

        self.moveLabel = gui.QLabel("", self)
        self.moveLabel.setText("")
        self.moveLabel.hide()
        self.moveLabel.setStyleSheet(
            "font-size:12px; margin:3px; padding:4px; background:#FFFFFF; border:2px solid #000;"
        )

        self.mpl_connect('button_press_event', self.onPress)
        self.mpl_connect('button_release_event', self.onRelease)
        self.mpl_connect('motion_notify_event', self.onMove)

    @staticmethod
    def _defaultWindow():
        """ Return the full, un-zoomed display window centred on (0, 0). """
        return core.QRectF(-1. * RENDER_PIXELS / 2., 1. * RENDER_PIXELS / 2.,
                           1. * RENDER_PIXELS, -1. * RENDER_PIXELS)

    def onPress(self, event):
        """ Start a rubberband selection on a left-button press.

        Middle and right buttons are currently ignored.
        """
        if event.button != 1:
            return
        xdata = event.xdata
        ydata = event.ydata
        # check if mouse is outside the figure
        if xdata is None or ydata is None:
            return

        self.pressed = True
        self.pressEvent = event
        # Forget the previous drag so a plain click cannot re-zoom using
        # a stale rubberband size in onRelease.
        self.deltaX = 0.
        self.deltaY = 0.

        # matplotlib's y axis points up, Qt's points down; Qt needs ints
        self.origin = core.QPoint(int(event.x), int(self.height() - event.y))
        self.rubberBand.setGeometry(core.QRect(self.origin, core.QSize()))
        self.rubberBand.show()

        # start point, in data coordinates
        self.startX = xdata
        self.startY = ydata

    def onMove(self, event):
        """ Resize the rubberband while dragging, else update the picker label. """
        xdata = event.xdata
        ydata = event.ydata
        if xdata is None or ydata is None:
            self.moveLabel.hide()
            return

        if self.pressed:  ## display rubberband
            if self.PICKER_LABEL:
                self.moveLabel.hide()

            deltaX = event.x - self.pressEvent.x  ## moved distance
            deltaY = event.y - self.pressEvent.y  ## for rubberband
            # keep the rubberband square, preserving the drag direction
            dx = dy = min(fabs(deltaX), fabs(deltaY))
            if deltaX < 0:
                dx = -dx
            if deltaY < 0:
                dy = -dy
            newRect = core.QRect(self.origin.x(), self.origin.y(), int(dx),
                                 -int(dy))
            self.rubberBand.setGeometry(newRect.normalized())
            self.deltaX = dx
            self.deltaY = dy
        elif self.PICKER_LABEL:  ## display label
            i, j = self.retrieve_z_value(xdata, ydata)
            self.moveLabel.show()
            if i is not None and j is not None:
                ## TODO: should use xdata or self.x[i][j]
                self.moveLabel.setText(
                    r"x=%g, y=%g, z=%g" % (xdata, ydata, self.var_[i][j]))
            else:
                self.moveLabel.setText(r"x=%g, y=%g, z=n/a" %
                                       (xdata, ydata))
            self.moveLabel.adjustSize()
            # flip the label to the left of the cursor near the right edge
            offset = 10
            if self.width() - event.x < self.moveLabel.width():
                offset = -10 - self.moveLabel.width()
            self.moveLabel.move(int(event.x + offset),
                                int(self.height() - event.y))

    def retrieve_z_value(self, xdata, ydata):
        """ Return indices (i, j) of the sample closest to (xdata, ydata).

        Only samples whose x lies within findNearest's tolerance are
        considered; returns (None, None) when there is none.
        """
        best = float('inf')
        iv = None
        jv = None
        for i, xrow in enumerate(self.x):
            j = self.findNearest(np.copy(xrow), xdata)
            if j is None:
                continue
            d = self.distance(xdata, ydata, xrow[j], self.y[i][j])
            if d < best:
                iv, jv, best = i, j, d
        return iv, jv

    def onRelease(self, event):
        """ Finish a left-button drag and zoom into the selected square. """
        if event.button != 1:
            return
        self.pressed = False
        self.rubberBand.hide()

        xdata = event.xdata  ## mouse real position
        ydata = event.ydata
        if xdata is None or ydata is None or self.startX is None or self.startY is None:
            return

        # Scale the rubberband size (pixels) into data units; presumably
        # FIGURE_CANCAS_RATIO is the plot-area fraction of the widget width.
        d0 = self.width() * FIGURE_CANCAS_RATIO
        xlim = self.axes.get_xlim()
        ylim = self.axes.get_ylim()
        x_range = xlim[1] - xlim[0]
        y_range = ylim[1] - ylim[0]
        x1, y1 = self.startX, self.startY
        x2 = x1 + self.deltaX / d0 * x_range
        y2 = y1 + self.deltaY / d0 * y_range

        oldRect = core.QRectF()  # current view, kept for un-zooming
        oldRect.setLeft(xlim[0])
        oldRect.setRight(xlim[1])
        oldRect.setBottom(ylim[0])
        oldRect.setTop(ylim[1])

        rect = core.QRectF()  # current rubberband rect
        rect.setLeft(min(x1, x2))
        rect.setRight(max(x1, x2))
        rect.setBottom(min(y1, y2))
        rect.setTop(max(y1, y2))

        ## react only when dragged region is greater than 0.01 times of old rect
        if fabs(self.deltaX) > ZOOM_WINDOW_PIXEL_LIMIT and \
           fabs(rect.width()) > ZOOM_WINDOW_WIDTH_LIMIT and \
           fabs(rect.width()) >= 0.01 * fabs(oldRect.width()):
            self.zoomer.append(oldRect)
            self.zoomTo(rect)
            self._zoomWindow = rect

    def zoomTo(self, rect):
        """ Set the axes limits to ``rect`` and redraw. """
        self.axes.set_xlim(rect.left(), rect.right())
        self.axes.set_ylim(rect.bottom(), rect.top())
        self.draw()

    def findNearest(self, array, target):
        """ Return the index of the value of ``array`` closest to ``target``.

        Returns None when no value lies within 0.151 of ``target``.
        """
        diff = abs(array - target)
        ## TODO: select a threshold (range:meters_between_gates = 150.000005960464)
        mask = np.ma.greater(diff, 0.151)
        if np.all(mask):
            return None  # nothing close enough to target
        masked_diff = np.ma.masked_array(diff, mask)
        return masked_diff.argmin()

    def distance(self, x1, y1, x2, y2):
        """ Return the Euclidean distance between two points. """
        return sqrt((x1 - x2)**2 + (y1 - y2)**2)

    def sizeHint(self):
        w, h = self.get_width_height()
        return core.QSize(w, h)

    def minimumSizeHint(self):
        return core.QSize(10, 10)

    def setWindow(self, window):
        """ initialize the full window to use for this widget """
        self._zoomWindow = window
        self._aspectRatio = window.width() / window.height()

    def resizeEvent(self, event):
        """ Resize the figure to the new widget size and redraw. """
        sz = event.size()
        width = sz.width()
        height = sz.height()
        dpival = self.fig.dpi
        self.fig.set_size_inches(float(width) / dpival, float(height) / dpival)
        self.fig.canvas.draw()
        self.origin = [width, height]

    def drawSweep(self, sweep, varName, beamWidth):
        """ Store ``sweep``'s data for ``varName`` and redraw the figure. """
        self.beamWidth = beamWidth
        self.ranges = sweep.ranges
        self.sweep = sweep
        self.varName = varName.lower()
        self.var_ = sweep.vars_[varName]  # in list
        self.x = sweep.x
        self.y = sweep.y
        self.label = self.headerLabel + sweep.timeLabel
        self.update_figure()

    def update_figure(self):
        """ Redraw the current variable; does nothing without data.

        Needs to be called explicitly.
        """
        if len(self.var_) == 0:
            return
        self.axes.clear()

        vmin = min(min(x) for x in self.var_)
        vmax = max(max(x) for x in self.var_)

        im = self.axes.pcolormesh(self.x,
                                  self.y,
                                  self.var_,
                                  vmin=vmin,
                                  vmax=vmax,
                                  cmap=self.cMap(self.varName))
        ## setup zeniths, azimuths, and colorbar
        if self.RANGE_RING:
            self.draw_range_ring()
        if self.AZIMUTH:
            self.draw_azimuth_line()
        if self.COLORBAR:
            self.draw_colorbar(im, vmin, vmax)

        # the zoom window is kept when switching variables
        self.zoomTo(self._zoomWindow)
        self.axes.set_title(self.label,
                            size=9)  ## TODO: change size to be adaptive
        self.fig.canvas.draw()

    def draw_azimuth_line(self):
        """ draw azimuths with 30-degree intervals """
        angles = np.arange(0, 360, 30)
        labels = [90, 60, 30, 0, 330, 300, 270, 240, 210, 180, 150, 120]
        x = R * np.cos(np.pi * angles / 180)
        y = R * np.sin(np.pi * angles / 180)

        for xi, yi, ang, lb in zip(x, y, angles, labels):
            line = plt.Line2D([0, xi], [0, yi],
                              linestyle='dashed',
                              color='lightgray',
                              lw=0.8)
            self.axes.add_line(line)
            # nudge labels (in points) so they sit outside the circle
            xo, yo = 0, 0
            if 90 < ang < 180:
                xo = -10
                yo = 3
            elif ang == 180:
                xo = -15
                yo = -3
            elif 180 < ang < 270:
                xo = -12
                yo = -10
            elif ang == 270:
                xo = -10
                yo = -8
            elif 270 < ang < 360:
                yo = -5
            self.axes.annotate(str(lb),
                               xy=(xi, yi),
                               xycoords='data',
                               xytext=(xo, yo),
                               textcoords='offset points',
                               arrowprops=None,
                               size=10)

    def draw_range_ring(self):
        """ Draw dashed range rings every 30 units up to R, labelled at 135 degrees. """
        zeniths = np.arange(0, R + 1, 30)
        angle = 135.
        for r in zeniths:
            circ = plt.Circle((0, 0),
                              radius=r,
                              linestyle='dashed',
                              color='lightgray',
                              lw=0.8,
                              fill=False)
            self.axes.add_patch(circ)
            x = R * np.cos(np.pi * angle / 180.) * r / R
            y = R * np.sin(np.pi * angle / 180.) * r / R
            self.axes.annotate(str(int(r)),
                               xy=(x, y),
                               xycoords='data',
                               arrowprops=None,
                               size=10)

    def _colorScaleName(self):
        """ Return the ``self.cMap.ticks_label`` key for the current variable.

        Tries the full variable name, then the prefix before its first '_'
        (e.g. 'vel_f' -> 'vel'). Shows a message box and raises
        RuntimeError when neither is defined.
        """
        ticks_label = self.cMap.ticks_label
        if self.varName in ticks_label:
            return self.varName
        u = self.varName.find('_')
        # find() returns -1 when absent; a leading '_' gives an empty prefix
        if u > 0 and self.varName[:u] in ticks_label:
            return self.varName[:u]
        message = """ Please define a color scale for '{0}' in your configuration file """.format(
            self.varName)
        msgBox = gui.QMessageBox()
        msgBox.setText(message)
        msgBox.exec_()
        raise RuntimeError(message)

    def draw_colorbar(self, im, vmin, vmax):
        """ Draw a horizontal colorbar just below the plot.

        ``im``, ``vmin`` and ``vmax`` are currently unused: the colorbar is
        built from ``self.cMap`` and the boundaries listed in
        ``self.cMap.ticks_label`` for the current variable.
        """
        if self.cb is not None:
            # Remove the previous colorbar axes (not assumed to be axes[1])
            if self.cb.ax in self.fig.axes:
                self.fig.delaxes(self.cb.ax)
            self.fig.subplots_adjust(right=0.90)

        l, b, w, h = self.axes.get_position().bounds
        cax = self.fig.add_axes([l, b - 0.06, w, 0.03])  # colorbar axes
        cmap = self.cMap(self.varName)
        bounds = self.cMap.ticks_label[self._colorScaleName()]
        norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
        self.cb = ColorbarBase(cax,
                               cmap=cmap,
                               norm=norm,
                               orientation='horizontal',
                               boundaries=bounds,
                               ticks=bounds)
        self.cb.ax.tick_params(labelsize=8)
        self.cb.set_ticklabels([str(i) for i in bounds], update_ticks=True)
        self.cb.set_label('Color Scale', size=8)

    def resetFactors(self):
        """ Clear the zoom history and show the full window again. """
        self.zoomer = []
        # Same window as __init__ (a QRect with positive height used to
        # give shifted, inverted y limits on the next zoomTo).
        self.setWindow(self._defaultWindow())
        # zoomTo applies the limits and redraws the canvas
        self.zoomTo(self._zoomWindow)

    def changeZoomerPointer(self, ind=None):
        """ Step back in the zoom history.

        With ``ind`` None, return to the previous zoom window; otherwise
        return to the first saved window and clear the history. Does
        nothing when the history is empty.
        """
        if not self.zoomer:
            return
        if ind is None:
            zoomWindow = self.zoomer[-1]
            self.zoomTo(zoomWindow)
            self.zoomer.pop()
        else:
            zoomWindow = self.zoomer[0]
            self.zoomTo(zoomWindow)
            self.zoomer = []
        # keep update_figure from re-applying the abandoned zoom
        self._zoomWindow = zoomWindow

    def getAspectRatio(self):
        return self._aspectRatio

    def keyPressEvent(self, event):
        """ Key 'C' resets the zoom to the full window. """
        if event.key() == core.Qt.Key_C:
            self.resetFactors()
            event.accept()

    '''
示例#22
0
class BaseSlicer(object):
    """ The main purpose of these class is to have auto adjust of axes size
        to the data with different layout of cuts.
    """
    # This actually encodes the figsize for only one axe
    _default_figsize = [2.2, 2.6]
    _axes_class = CutAxes

    def __init__(self, cut_coords, axes=None, black_bg=False, **kwargs):
        """ Create 3 linked axes for plotting orthogonal cuts.

            Parameters
            ----------
            cut_coords: 3 tuple of ints
                The cut position, in world space.
            axes: matplotlib axes object, optional
                The axes that will be subdivided in 3. When omitted, a
                new axes spanning (0, 0, 1, 1) of the current pylab
                figure is created with its axis turned off.
            black_bg: boolean, optional
                If True, the background of the figure will be put to
                black. If you wish to save figures with a black background,
                you will need to pass "facecolor='k', edgecolor='k'" to
                pylab's savefig.
            kwargs:
                Forwarded to ``self._init_axes``.
        """
        self.cut_coords = cut_coords
        if axes is None:
            axes = pl.axes((0., 0., 1., 1.))
            axes.axis('off')
        self.frame_axes = axes
        axes.set_zorder(1)

        box = axes.get_position()
        self.rect = (box.x0, box.y0, box.x1, box.y1)
        self._black_bg = black_bg
        self._colorbar = False
        # Colorbar geometry, as fractions of the frame axes size
        self._colorbar_width = 0.05 * box.width
        self._colorbar_margin = {
            'left': 0.25 * box.width,
            'right': 0.02 * box.width,
            'top': 0.05 * box.height,
            'bottom': 0.05 * box.height,
        }
        self._init_axes(**kwargs)

    @staticmethod
    def find_cut_coords(img=None, threshold=None, cut_coords=None):
        """ Return the cut coordinates used by ``init_with_figure``.

            Not implemented on the base class: subclasses must provide
            it, as a staticmethod or a classmethod.
        """
        raise NotImplementedError

    @classmethod
    def init_with_figure(cls,
                         img,
                         threshold=None,
                         cut_coords=None,
                         figure=None,
                         axes=None,
                         black_bg=False,
                         leave_space=False,
                         colorbar=False,
                         **kwargs):
        """ Build a slicer together with the figure and axes it draws in.

            Parameters
            ----------
            img: Niimg-like object, None or False
                Image used to find the cut coordinates. Unless None or
                False, it goes through ``_utils.check_niimg`` with
                ``ensure_3d=True``.
            threshold, cut_coords:
                Forwarded to ``cls.find_cut_coords``.
            figure: matplotlib figure, figure number or None, optional
                Figure to draw in. When it is not a Figure instance, a
                new pylab figure is created (this value is passed as its
                number) and sized from the number of cuts.
            axes: matplotlib axes or [left, bottom, width, height], optional
                Axes, or rectangle in figure coordinates, hosting the
                slicer. Defaults to the whole figure.
            black_bg: boolean, optional
                Use a black figure background.
            leave_space: boolean, optional
                Widen a new figure and keep its left 30% free.
            colorbar: boolean, optional
                Widen a new figure to make room for a colorbar.
            kwargs:
                Forwarded to the class constructor.
        """
        # deal with "fake" 4D images
        if img is not None and img is not False:
            img = _utils.check_niimg(img, ensure_3d=True)

        cut_coords = cls.find_cut_coords(img, threshold, cut_coords)
        facecolor = 'k' if black_bg else 'w'

        if isinstance(axes, pl.Axes) and figure is None:
            figure = axes.figure

        if not isinstance(figure, pl.Figure):
            # Make sure that we have a figure
            figsize = cls._default_figsize[:]

            # Adjust for the number of axes
            figsize[0] *= len(cut_coords)

            # Make space for the colorbar
            if colorbar:
                figsize[0] += .7

            if leave_space:
                figsize[0] += 3.4
            figure = pl.figure(figure, figsize=figsize, facecolor=facecolor)

        if isinstance(axes, pl.Axes):
            assert axes.figure is figure, ("The axes passed are not "
                                           "in the figure")

        if axes is None:
            axes = [0.3, 0, .7, 1.] if leave_space else [0., 0., 1., 1.]
        if not isinstance(axes, pl.Axes):
            # Anything that is not already an Axes is a rectangle
            # [left, bottom, width, height]. (operator.isSequenceType,
            # previously used here, does not exist on Python 3.)
            axes = figure.add_axes(axes)
        # People forget to turn their axis off, or to set the zorder, and
        # then they cannot see their slicer
        axes.axis('off')
        return cls(cut_coords, axes, black_bg, **kwargs)

    def title(self,
              text,
              x=0.01,
              y=0.99,
              size=15,
              color=None,
              bgcolor=None,
              alpha=1,
              **kwargs):
        """ Write a title to the view.

            Parameters
            ----------
            text: string
                The text of the title
            x: float, optional
                The horizontal position of the title on the frame in
                fraction of the frame width.
            y: float, optional
                The vertical position of the title on the frame in
                fraction of the frame height.
            size: integer, optional
                The size of the title text.
            color: matplotlib color specifier, optional
                The color of the font of the title. Defaults to black on
                a black background, white otherwise.
            bgcolor: matplotlib color specifier, optional
                The color of the background of the title. Defaults to
                white on a black background, black otherwise.
            alpha: float, optional
                The alpha value for the background.
            kwargs:
                Extra keyword arguments are passed to matplotlib's text
                function.
        """
        if color is None:
            color = 'k' if self._black_bg else 'w'
        if bgcolor is None:
            bgcolor = 'w' if self._black_bg else 'k'

        # The title lives in the axes of the first displayed cut (or the
        # first cut coordinate when that information is absent).
        if hasattr(self, '_cut_displayed'):
            cuts = self._cut_displayed
        else:
            cuts = self.cut_coords
        target_ax = self.axes[cuts[0]].ax

        box_style = dict(boxstyle="square,pad=.3",
                         ec=bgcolor,
                         fc=bgcolor,
                         alpha=alpha)
        target_ax.text(x, y, text,
                       transform=self.frame_axes.transAxes,
                       horizontalalignment='left',
                       verticalalignment='top',
                       size=size,
                       color=color,
                       bbox=box_style,
                       zorder=1000,
                       **kwargs)
        # raise that axes so the title is not hidden by the other views
        target_ax.set_zorder(1000)

    def add_overlay(self, img, threshold=1e-6, colorbar=False, **kwargs):
        """ Plot a 3D map in all the views.

            Parameters
            -----------
            img: Niimg-like object
                See http://nilearn.github.io/building_blocks/manipulating_mr_images.html#niimg.
                If it is a masked array, only the non-masked part will be
                plotted.
            threshold : a number, None
                If None is given, the maps are not thresholded.
                If a number is given, it is used to threshold the maps:
                values below the threshold (in absolute value) are
                plotted as transparent.
            colorbar: boolean, optional
                If True, display a colorbar on the right of the plots.
                Only one overlay per figure may have a colorbar.
            kwargs:
                Extra keyword arguments are passed to imshow.

            Raises
            ------
            ValueError
                If ``colorbar`` is requested while a previous overlay
                already drew one.
        """
        if colorbar:
            if self._colorbar:
                raise ValueError("This figure already has an overlay with a "
                                 "colorbar.")
            # Only ever switch the flag on: an overlay without colorbar
            # must not make the figure forget the colorbar already drawn.
            self._colorbar = colorbar

        img = _utils.check_niimg(img, ensure_3d=True)

        if threshold is not None:
            data = img.get_data()
            if threshold == 0:
                data = np.ma.masked_equal(data, 0, copy=False)
            else:
                data = np.ma.masked_inside(data,
                                           -threshold,
                                           threshold,
                                           copy=False)
            img = nibabel.Nifti1Image(data, img.get_affine())

        # To make sure that add_overlay has a consistent default behavior
        # with plot_stat_map
        kwargs.setdefault('interpolation', 'nearest')
        ims = self._map_show(img, type='imshow', **kwargs)

        if colorbar:
            self._colorbar_show(ims[0], threshold)

        pl.draw_if_interactive()

    def add_contours(self, img, **kwargs):
        """ Draw contour lines of a 3D map on every view.

            Parameters
            -----------
            img: Niimg-like object
                See http://nilearn.github.io/building_blocks/manipulating_mr_images.html#niimg.
                Image whose contours are drawn.
            kwargs:
                Extra keyword arguments handed on to matplotlib's contour
                (see the pylab.contour documentation). The most useful
                ones are usually "levels", a list of values at which to
                draw contours, and "colors", one color or a list of
                colors for those contours.
        """
        contour_options = dict(kwargs)
        # _map_show does the per-view cutting; 'contour' selects the drawer
        self._map_show(img, type='contour', **contour_options)
        pl.draw_if_interactive()

    def _map_show(self,
                  img,
                  type='imshow',
                  resampling_interpolation='continuous',
                  **kwargs):
        """ Draw a 3D image in every display axis and return the artists.

            Parameters
            ----------
            img: Niimg-like object
                The image to draw. It is first passed through reorder_img
                with resample=resampling_interpolation.
            type: string, optional
                Drawing type forwarded to each axis' draw_2d, e.g. 'imshow'
                or 'contour'.
            resampling_interpolation: string, optional
                Value passed to reorder_img as its `resample` argument.
            kwargs:
                Extra keyword arguments forwarded to draw_2d. When 'vmin'
                or 'vmax' is not given, it defaults to the minimum or
                maximum over all 2D cuts that fall inside the data.

            Returns
            -------
            ims: list
                The values returned by draw_2d, one per axis whose cut falls
                inside the data. The list is empty if no cut does.
        """
        img = reorder_img(img, resample=resampling_interpolation)

        affine = img.get_affine()
        data = img.get_data()
        data_bounds = get_bounds(data.shape, affine)
        (xmin, xmax), (ymin, ymax), (zmin, zmax) = data_bounds

        xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \
                                        xmin, xmax, ymin, ymax, zmin, zmax

        if hasattr(data, 'mask') and isinstance(data.mask, np.ndarray):
            # Restrict the bounding box to the non-masked voxels.
            # The builtin int is used because the np.int alias was removed
            # in NumPy 1.24.
            not_mask = np.logical_not(data.mask)
            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \
                    get_mask_bounds(nibabel.Nifti1Image(not_mask.astype(int),
                                    affine))

        # self.axes.values() is used here and in the drawing loop below.
        # Both loops therefore pair cuts and axes in the same order, and the
        # code also runs on Python 3, where dict.itervalues does not exist.
        data_2d_list = []
        for display_ax in self.axes.values():
            try:
                data_2d = display_ax.transform_to_2d(data, affine)
            except IndexError:
                # We are cutting outside the indices of the data
                data_2d = None

            data_2d_list.append(data_2d)

        # Only cuts inside the data contribute to the default color limits.
        # If there are none, min()/max() of an empty sequence would raise
        # ValueError, so vmin/vmax are left unset instead.
        valid_data_2d = [d for d in data_2d_list if d is not None]
        if valid_data_2d:
            if 'vmin' not in kwargs:
                kwargs['vmin'] = min(d.min() for d in valid_data_2d)
            if 'vmax' not in kwargs:
                kwargs['vmax'] = max(d.max() for d in valid_data_2d)

        bounding_box = (xmin_, xmax_), (ymin_, ymax_), (zmin_, zmax_)

        ims = []
        for display_ax, data_2d in zip(self.axes.values(), data_2d_list):
            if data_2d is not None:
                im = display_ax.draw_2d(data_2d,
                                        data_bounds,
                                        bounding_box,
                                        type=type,
                                        **kwargs)
                ims.append(im)
        return ims

    def _colorbar_show(self, im, threshold):
        """ Add a vertical colorbar for `im` at the right of the display.

            Parameters
            ----------
            im: matplotlib image
                The artist whose `.norm` and `.cmap` the colorbar reflects.
            threshold: a number or None
                Colormap entries between -threshold and +threshold (after
                normalization by im.norm) are painted gray in the colorbar.
                None means no gray band. The threshold is capped at
                im.norm.vmax.
        """
        if threshold is None:
            offset = 0
        else:
            offset = threshold
        if offset > im.norm.vmax:
            offset = im.norm.vmax

        # create new axis for the colorbar
        figure = self.frame_axes.figure
        _, y0, x1, y1 = self.rect
        height = y1 - y0
        x_adjusted_width = self._colorbar_width / len(self.axes)
        x_adjusted_margin = self._colorbar_margin['right'] / len(self.axes)
        lt_wid_top_ht = [
            x1 - (x_adjusted_width + x_adjusted_margin),
            y0 + self._colorbar_margin['top'], x_adjusted_width, height -
            (self._colorbar_margin['top'] + self._colorbar_margin['bottom'])
        ]
        # NOTE(review): `axis_bgcolor` was removed in matplotlib 2.2
        # (replaced by `facecolor`). It is kept here to match the
        # matplotlib version the rest of this class targets.
        self._colorbar_ax = figure.add_axes(lt_wid_top_ht, axis_bgcolor='w')

        our_cmap = im.cmap
        # edge case where the data has a single value
        # yields a cryptic matplotlib error message
        # when trying to plot the color bar
        nb_ticks = 5 if im.norm.vmin != im.norm.vmax else 1
        ticks = np.linspace(im.norm.vmin, im.norm.vmax, nb_ticks)
        bounds = np.linspace(im.norm.vmin, im.norm.vmax, our_cmap.N)

        # Paint the colormap entries between -offset and +offset gray to
        # show the thresholded range. clip=True keeps the normalized
        # positions inside [0, 1]. Without it, when -offset (or +offset)
        # lies below vmin, the computed index is negative, and negative
        # list indices would gray out entries at the *top* of the colormap.
        cmaplist = [our_cmap(i) for i in range(our_cmap.N)]
        istart = int(im.norm(-offset, clip=True) * (our_cmap.N - 1))
        istop = int(im.norm(offset, clip=True) * (our_cmap.N - 1))
        for i in range(istart, istop):
            cmaplist[i] = (0.5, 0.5, 0.5, 1.)  # just an average gray color
        our_cmap = our_cmap.from_list('Custom cmap', cmaplist, our_cmap.N)

        self._cbar = ColorbarBase(self._colorbar_ax,
                                  ticks=ticks,
                                  norm=im.norm,
                                  orientation='vertical',
                                  cmap=our_cmap,
                                  boundaries=bounds,
                                  spacing='proportional')
        self._cbar.set_ticklabels(["%.2g" % t for t in ticks])

        # Put tick labels on the left and color them to contrast with the
        # display background; hide the tick marks themselves (width=0).
        self._colorbar_ax.yaxis.tick_left()
        tick_color = 'w' if self._black_bg else 'k'
        for tick in self._colorbar_ax.yaxis.get_ticklabels():
            tick.set_color(tick_color)
        self._colorbar_ax.yaxis.set_tick_params(width=0)

        self._cbar.update_ticks()

    def add_edges(self, img, color='r'):
        """ Plot the edges of a 3D map in all the views.

            Parameters
            -----------
            img: Niimg-like object
                The 3D map whose edges are plotted. It is passed through
                reorder_img before being cut.
            color: matplotlib color: string or (r, g, b) value
                The color used to display the edge map
        """
        img = reorder_img(img)
        data = img.get_data()
        affine = img.get_affine()
        single_color_cmap = colors.ListedColormap([color])
        data_bounds = get_bounds(data.shape, affine)

        # For each ax, cut the data and plot its edges.
        # .values() is used instead of the Python-2-only .itervalues().
        for display_ax in self.axes.values():
            try:
                data_2d = display_ax.transform_to_2d(data, affine)
                edge_mask = _edge_map(data_2d)
            except IndexError:
                # We are cutting outside the indices of the data
                continue
            # data_bounds is passed both as the data bounds and as the
            # bounding box, so no mask-based cropping is applied.
            display_ax.draw_2d(edge_mask,
                               data_bounds,
                               data_bounds,
                               type='imshow',
                               cmap=single_color_cmap)

        pl.draw_if_interactive()

    def annotate(self, left_right=True, positions=True, size=12, **kwargs):
        """ Draw text annotations on every display axis.

            Parameters
            ----------
            left_right: boolean, optional
                When True, draw the labels marking the left and right
                sides of each view.
            positions: boolean, optional
                When True, draw the labels giving the position of each cut.
            size: integer, optional
                Font size of the annotation text.
            kwargs:
                Extra keyword arguments forwarded to matplotlib's text
                function. Without an explicit 'color', white is used on a
                black background and black otherwise.
        """
        text_kwargs = dict(kwargs)
        text_kwargs.setdefault('color', 'w' if self._black_bg else 'k')

        for display_ax in self.axes.values():
            # On a black background the labels get an opaque box, so slice
            # lines do not cross the text. Otherwise no box color is passed.
            bg_color = None
            if self._black_bg:
                mpl_ax = display_ax.ax
                bg_color = (mpl_ax.get_axis_bgcolor()
                            or mpl_ax.get_figure().get_facecolor())

            if left_right:
                display_ax.draw_left_right(size=size, bg_color=bg_color,
                                           **text_kwargs)
            if positions:
                display_ax.draw_position(size=size, bg_color=bg_color,
                                         **text_kwargs)

    def close(self):
        """ Close the matplotlib figure that holds this display.

            Closing it is necessary to avoid leaking memory.
        """
        figure_number = self.frame_axes.figure.number
        pl.close(figure_number)

    def savefig(self, filename, dpi=None):
        """ Save the figure to a file.

            Parameters
            ----------
            filename: string
                The file name to save to. Its extension determines the
                file type, typically '.png', '.svg' or '.pdf'.

            dpi: None or scalar
                The resolution in dots per inch.
        """
        # The saved face and edge colors follow the display background.
        background = 'k' if self._black_bg else 'w'
        figure = self.frame_axes.figure
        figure.savefig(filename, dpi=dpi,
                       facecolor=background, edgecolor=background)
示例#23
0
def change_cbar_text(cbar: ColorbarBase, tick: List[Number_T], text: List[str]):
    """Set the tick positions of a colorbar and the labels shown at them.

    The colorbar is modified in place. No other figure state is touched.

    Parameters
    ----------
    cbar : ColorbarBase
        The colorbar to relabel.
    tick : list
        Tick positions along the colorbar.
    text : list of str
        Labels for the ticks, in the same order as ``tick``.
    """
    cbar.set_ticks(tick)
    cbar.set_ticklabels(text)

# NOTE(review): `sc` is not defined anywhere in this view. It is presumably
# a Spark context created earlier in the original script (TODO confirm).
# This line calls its stop() method.
sc.stop()
示例#25
0
def change_cbar_text(cbar: ColorbarBase, tick: list, text: list):
    """Relabel a colorbar: put ticks at ``tick`` and show ``text`` there.

    Tick positions are applied first, then their labels (same order).
    """
    set_positions, set_labels = cbar.set_ticks, cbar.set_ticklabels
    set_positions(tick)
    set_labels(text)
示例#26
0
class Bar_Plot(QtGui.QWidget):
    """Qt widget showing per-channel band-power bar charts on a head outline.

    Frames are read from a ``WS_PB`` client (``self.ws_data``) and redrawn
    every ``timer_interval`` seconds. Clicking a channel's bar chart opens
    a ``Big_Bar_Plot`` for that channel.
    """

    def __init__(self, url=None):
        super().__init__()

        if url is None:
            url = "ws://localhost:7777"

        self.ws_data = WS_PB(url=url, plot_name="PB_bar")
        self.setWindowTitle("Bar Plot")
        self.timer_interval = 0.5  # seconds between redraws

        # Block until ws_data.ch_label is populated. It is presumably filled
        # by the websocket client in the background (TODO confirm). Sleep
        # between checks instead of spinning a CPU core at 100%.
        import time
        while self.ws_data.ch_label is None:
            time.sleep(0.01)

        self.num_color = 8  # number of frequency bands / bar colors
        self.big_plots = [None] * len(self.ws_data.ch_label)
        self.init_tick = None
        self.cur_tick = None

        self.init_ui()
        self.setup_signal_handler()
        self.show()
        self.value = 0

    def init_ui(self):
        """Build the figure: head outline, one bar axis per channel, colorbar."""
        self.cmap = mpl.cm.get_cmap('Dark2')
        self.norm = mpl.colors.Normalize(vmin=0, vmax=self.num_color)
        # Axes layout: axes[0] is the head outline, axes[1:-1] hold one bar
        # chart per channel, and axes[-1] holds the band colorbar.
        self.fig, self.axes = plt.subplots(len(self.ws_data.ch_label) + 2, 1)

        self.colorbar = ColorbarBase(
            self.axes[-1],
            cmap=self.cmap,
            norm=self.norm,
            ticks=[i + 0.5 for i in range(self.num_color)])
        self.colorbar.set_ticklabels([
            'Delta (1-3 Hz)', 'Theta (4-7 Hz)', 'Low Alpha (8-10 Hz)',
            'High Alpha (11-12 Hz)', 'Low Beta (13-15 Hz)',
            'Mid Beta (16-19 Hz)', 'High Beta (20-35 Hz)', 'Gamma (36-50 Hz)'
        ])

        self.canvas = FigureCanvas(self.fig)

        self.pos = list(channel_dict_2D.values())  # get all X,Y values
        self.ch_names_ = list(
            channel_dict_2D.keys())  # get all channel's names

        self.pos, self.outlines = topomap._check_outlines(
            self.pos, 'head')  # from mne.viz libs, normalize the pos

        topomap._draw_outlines(self.axes[0], self.outlines)
        self.plt_idx = [
            self.ch_names_.index(name) for name in self.ws_data.ch_label
        ]  # get the index of those required channels
        ch_pos = [self.pos[idx] * 5 / 6 for idx in self.plt_idx]
        for idx, pos in enumerate(ch_pos):
            pos += [0.47, 0.47]
            self.axes[idx + 1].set_position(list(pos) + [0.06, 0.06])
            self.axes[idx + 1].axis("off")
            self.axes[idx + 1].grid(True, axis='y')

        self.axes[0].set_position([0, 0, 1, 1])
        self.axes[-1].set_position([0.80, 0.85, 0.03, 0.13])
        self.resize(1000, 800)

        hlayout = QtGui.QHBoxLayout(self)
        hlayout.addWidget(self.canvas)

    def setup_signal_handler(self):
        """Connect mouse clicks and start the periodic redraw timer."""
        # Keep the connection id so the handler can be disconnected later.
        self._click_cid = self.fig.canvas.mpl_connect('button_press_event',
                                                      self.onclick)
        self.timer = QtCore.QTimer()
        # QTimer.setInterval takes an int number of milliseconds. A float
        # raises TypeError with recent PyQt / Python versions.
        self.timer.setInterval(int(self.timer_interval * 1000))
        self.timer.timeout.connect(self.draw)
        self.timer.start()

    def onclick(self, event):
        """Open a Big_Bar_Plot for the channel whose bar chart was clicked."""
        # Only axes[1:-1] are channel charts. The last axis is the colorbar,
        # and mapping it to a channel would index past ch_label.
        for idx, ax in enumerate(self.axes[1:-1]):
            if ax == event.inaxes:
                print(self.ws_data.ch_label[idx])
                if self.big_plots[idx] is None:
                    self.big_plots[idx] = Big_Bar_Plot(
                        self, self.ws_data.ch_label[idx])

    def big_plot_closed(self, plot):
        """Forget a Big_Bar_Plot once it has been closed."""
        for idx, p in enumerate(self.big_plots):
            if plot == p:
                self.big_plots[idx] = None

    def draw(self):
        """Timer callback: drop stale frames, then draw the next power frame."""
        if self.init_tick is None and self.ws_data.ticks:
            self.init_tick = self.ws_data.ticks[0]
            self.cur_tick = self.init_tick

        # Delete outdated frames. `is not None` (not truthiness) so that a
        # stream whose first tick is 0 still advances.
        # NOTE(review): +500 presumably matches the 0.5 s timer interval,
        # with ticks in milliseconds (TODO confirm).
        count = 0
        if self.cur_tick is not None:
            self.cur_tick += 500
            due_limit = self.cur_tick + 250
            count = sum(1 for tick in self.ws_data.ticks if tick <= due_limit)

        # Keep only the newest due frame, popping all four queues together.
        while count > 1:
            self.ws_data.ticks.pop(0)
            self.ws_data.power_data.pop(0)
            self.ws_data.z_all_data.pop(0)
            self.ws_data.z_each_data.pop(0)
            count -= 1

        if self.ws_data.power_data:
            ch_data = self.ws_data.power_data.pop(0)
            # The band colors are the same for every channel: compute once.
            band_colors = [self.cmap(self.norm(i))
                           for i in range(self.num_color)]
            for idx, data in enumerate(ch_data):
                ax = self.axes[idx + 1]
                ax.cla()
                ax.bar(list(range(self.num_color)), data, color=band_colors)
                for side in ("bottom", "top", "right", "left"):
                    ax.spines[side].set_visible(False)
                ax.yaxis.grid(True, c='r')
                ax.get_xaxis().set_visible(False)
                if self.big_plots[idx] is not None:
                    self.big_plots[idx].draw(data, band_colors)
            self.fig.canvas.draw()

        # Consume (and ignore) the matching entries of the other queues so
        # they stay in step with power_data.
        if self.ws_data.ticks:
            self.ws_data.ticks.pop(0)

        if self.ws_data.z_all_data:
            self.ws_data.z_all_data.pop(0)

        if self.ws_data.z_each_data:
            self.ws_data.z_each_data.pop(0)

        gc.collect()