Example #1
0
 def get_dict_from_marker_style(line):
     """Collect the marker styling of *line* into a plain dict.

     :param line: presumably a matplotlib ``Line2D`` -- any object exposing
         the ``get_marker*`` and ``get_zorder`` accessors used below.
     :return: dict with keys ``faceColor`` and ``edgeColor`` (hex strings),
         ``edgeWidth``, ``markerType``, ``markerSize`` and ``zOrder``.
     """
     face_hex = to_hex(line.get_markerfacecolor())
     edge_hex = to_hex(line.get_markeredgecolor())
     return dict(
         faceColor=face_hex,
         edgeColor=edge_hex,
         edgeWidth=line.get_markeredgewidth(),
         markerType=line.get_marker(),
         markerSize=line.get_markersize(),
         zOrder=line.get_zorder(),
     )
Example #2
0
    def plot_cdf(self, workload='jankbench', metric='frame_total_duration',
                 threshold=16, tag='.*', kernel='.*', test='.*'):
        """
        Display cumulative distribution functions of a certain metric

        Draws CDFs of metrics in the results. Check ``workloads`` and
        ``workload_available_metrics`` to find the available workloads and
        metrics. Check ``tags``, ``tests`` and ``kernels`` to find the
        names that results can be filtered against.

        The most likely use-case for this is plotting frame rendering times
        under Jankbench, so default parameters are provided to make this easy.

        One CDF is drawn for every (test, tag, kernel) group of the filtered
        results, each with a dashed horizontal line at the fraction of samples
        below ``threshold``. If no data matches the filters nothing is plotted.

        :param workload: Name of workload to display metrics for
        :param metric: Name of metric to display

        :param threshold: Value to highlight in the plot - the likely use for
                          this is highlighting the maximum acceptable
                          frame-rendering time in order to see at a glance the
                          rough proportion of frames that were rendered in time.

        :param tag: regular expression to filter tags that should be plotted
        :param kernel: regular expression to filter kernels that should be plotted
        :param test: regular expression to filter tests that should be plotted
        """
        df = self._get_metric_df(workload, metric, tag, kernel, test)
        if df is None:
            return

        # Group once and reuse the grouping for both the count and the loop.
        groups = df.groupby(['test', 'tag', 'kernel'])
        colors = iter(cm.rainbow(np.linspace(0, 1, len(groups) + 1)))

        _, axes = plt.subplots()
        # Shade the region from 0 up to the threshold.
        axes.axvspan(0, threshold, facecolor='g', alpha=0.1)

        labels = []
        lines = []
        for keys, group_df in groups:
            # keys is (test, tag, kernel); the legend shows kernel and tag.
            label = "{:16s}: {:32s}".format(keys[2], keys[1])
            labels.append(label)
            color = to_hex(next(colors))
            cdf = self._get_cdf(group_df['value'], threshold)
            [units] = group_df['units'].unique()
            # NOTE: the title is rewritten on every iteration, so the final
            # figure title describes the last group only.
            ax = cdf.df.plot(ax=axes, legend=False, xlim=(0, None), figsize=(16, 6),
                             title='Total duration CDF ({:.1f}% within {} [{}] threshold)'
                             .format(100. * cdf.below, threshold, units),
                             label=label,
                             color=color)
            lines.append(ax.lines[-1])
            axes.axhline(y=cdf.below, linewidth=1,
                         linestyle='--', color=color)
            self._log.debug("%-32s: %-32s: %.1f", keys[2], keys[1], 100. * cdf.below)

        axes.grid(True)
        axes.legend(lines, labels)
        plt.show()
Example #3
0
def test_cn():
    # Each case: (colours for the property cycle, expected "C0", expected "C1").
    cases = [
        (['blue', 'r'], '#0000ff', '#ff0000'),
        (['xkcd:blue', 'r'], '#0343df', '#ff0000'),
    ]
    for cycle_colors, expected_c0, expected_c1 in cases:
        matplotlib.rcParams['axes.prop_cycle'] = cycler('color', cycle_colors)
        assert mcolors.to_hex("C0") == expected_c0
        assert mcolors.to_hex("C1") == expected_c1
Example #4
0
def test_conversions():
    # "none" converts to an empty RGBA array of shape (0, 4).
    empty_rgba = np.zeros((0, 4))
    assert_array_equal(mcolors.to_rgba_array("none"), empty_rgba)
    # An explicit alpha argument is attached to an RGB tuple.
    assert_equal(mcolors.to_rgba((1, 1, 1), .5), (1, 1, 1, .5))
    # Rounding differs between Python 2 and 3, so the hex result is pinned.
    assert_equal(mcolors.to_hex((.7, .7, .7)), "#b2b2b2")
    # A hex string with alpha survives a round trip through RGBA.
    hex_color = "#1234abcd"
    roundtrip = mcolors.to_hex(mcolors.to_rgba(hex_color), keep_alpha=True)
    assert_equal(roundtrip, hex_color)
Example #5
0
def _get_motif_tree(tree, data, circle=True, vmin=None, vmax=None):
    """Build an ete3 tree whose branches are coloured and sized by ``data``.

    Every node's branch colour comes from an ``RdBu_r`` colour scale applied
    to the mean of ``data`` over the node's leaves (looked up by leaf name);
    the branch width is ``25 / data.values.max()`` times that mean, with a
    minimum of 5.

    :param tree: tree specification passed to ``ete3.Tree`` (presumably a
        Newick string or file -- confirm with callers).
    :param data: indexable by a list of leaf names and exposing ``.values``;
        presumably a pandas Series indexed by leaf name -- confirm.
    :param circle: if true, draw in circular mode over a 180-degree arc.
    :param vmin: lower bound of the colour scale; computed when ``None``.
    :param vmax: upper bound of the colour scale; computed when ``None``.
    :return: ``(tree, tree_style)`` tuple.
    """
    try:
        from ete3 import Tree, NodeStyle, TreeStyle
    except ImportError:
        print("Please install ete3 to use this functionality")
        sys.exit(1)

    t = Tree(tree)

    # Determine cutoff for color scale: the first percentile from 90 to 100
    # whose value is positive. Only computed when a bound is missing; use
    # ``is None`` so an explicit bound of 0 is respected.
    if vmin is None or vmax is None:
        for i in range(90, 101):
            minmax = np.percentile(data.values, i)
            if minmax > 0:
                break
        if vmin is None:
            vmin = -minmax
        if vmax is None:
            vmax = minmax

    norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
    mapper = cm.ScalarMappable(norm=norm, cmap="RdBu_r")

    # Scale so that the maximum data value corresponds to a width of 25.
    m = 25 / data.values.max()

    for node in t.traverse("levelorder"):
        val = data[[leaf.name for leaf in node.get_leaves()]].values.mean()
        style = NodeStyle()
        style["size"] = 0

        line_color = to_hex(mapper.to_rgba(val))
        style["hz_line_color"] = line_color
        style["vt_line_color"] = line_color

        width = max(np.abs(m * val), 5)
        style["vt_line_width"] = width
        style["hz_line_width"] = width

        node.set_style(style)

    ts = TreeStyle()

    ts.layout_fn = _tree_layout
    ts.show_leaf_name = False
    ts.show_scale = False
    ts.branch_vertical_margin = 10

    if circle:
        ts.mode = "c"
        ts.arc_start = 180  # 0 degrees = 3 o'clock
        ts.arc_span = 180

    return t, ts
Example #6
0
    def to_rgba_hex(c, a):
        """
        Convert an rgb color to an rgba hex value

        If color c already has an alpha channel, then alpha value
        a is ignored and c is converted with its own alpha kept.

        :param c: any matplotlib color specification
        :param a: alpha to apply when ``c`` has no alpha channel
        :return: hex string including an alpha component
        """
        if has_alpha(c):
            return mcolors.to_hex(c, keep_alpha=True)
        # to_rgba forces the alpha to ``a``; this is the public equivalent of
        # the legacy ``colorConverter.to_rgba`` and needs no hex round trip.
        return mcolors.to_hex(mcolors.to_rgba(c, a), keep_alpha=True)
 def cell_color(val):
     """Return a CSS ``background-color`` declaration for a table cell.

     Values above 0.9 get no background. Other values are mapped onto a
     reversed 100-colour seaborn ``Reds`` palette at index
     ``int((val - min_robustness) * 50)``, clamped to the palette's range.

     :param val: numeric score of the cell
     :return: CSS declaration string
     """
     if val > .9:
         return 'background-color: None'
     palette = sns.color_palette('Reds', n_colors=100)[::-1]
     # Clamp: a negative index would silently wrap around to the other end
     # of the palette, and an index >= 100 would raise IndexError.
     idx = int((val - min_robustness) * 50)
     idx = min(max(idx, 0), len(palette) - 1)
     color = to_hex(palette[idx])
     return 'background-color: %s' % color
Example #8
0
def swap_colors(json_file_path):
    '''
    Switches out color ramp in meta.json files.
    Uses custom color ramp if provided and valid; otherwise falls back to nextstrain default colors.
    N.B.: Modifies json in place and writes to original file path.

    For every entry of ``color_options`` with a non-empty ``color_map`` (a
    list of ``[category, color]`` pairs) the colors are replaced by:

    * the module-level ``custom_colors`` mapping, if it is set and has a
      color for every category;
    * otherwise, when there are more categories than ``default_colors`` has
      entries, colors interpolated along the last ``default_colors`` entry;
    * otherwise ``default_colors[len(categories)]``, in category order.

    :param json_file_path: path of the meta.json file to rewrite in place
    '''
    with open(json_file_path, 'r') as fh:
        j = json.load(fh)
    color_options = j['color_options']

    for k, v in color_options.items():
        # Nothing to recolor (and ``zip(*[])`` cannot be unpacked below).
        if 'color_map' not in v or not v['color_map']:
            continue
        categories, colors = zip(*v['color_map'])

        ## Use custom colors if provided AND present for all categories in the dataset
        if custom_colors and all(category in custom_colors for category in categories):
            colors = [custom_colors[category] for category in categories]

        ## Expand the color palette if we have too many categories
        elif len(categories) > len(default_colors):
            from matplotlib.colors import LinearSegmentedColormap, to_hex
            from numpy import linspace
            expanded_cmap = LinearSegmentedColormap.from_list('expanded_cmap', default_colors[-1], N=len(categories))
            discrete_colors = [expanded_cmap(i) for i in linspace(0, 1, len(categories))]
            colors = [to_hex(c).upper() for c in discrete_colors]

        else:  ## Falls back to default nextstrain colors
            colors = default_colors[len(categories)]  # based on how many categories are present; keeps original ordering

        # Materialise real lists: a Python 3 ``map`` object is a lazy
        # iterator that json.dump cannot serialise.
        j['color_options'][k]['color_map'] = [list(pair) for pair in zip(categories, colors)]

    with open(json_file_path, 'w') as fh:
        json.dump(j, fh, indent=1)
Example #9
0
def test_color_cycle():
    import re

    cyc = plot_utils.color_cycle()
    assert isinstance(cyc, itertools.cycle)
    # Compare (major, minor) numerically: a plain string comparison is
    # lexicographic and misorders versions such as '1.10' vs '1.5'.
    # str() also accepts version objects.
    version = tuple(int(part) for part in re.findall(r'\d+', str(mpl_version))[:2])
    if version < (1, 5):
        assert next(cyc) == 'b'
    else:
        assert next(cyc) == mpl_colors.to_hex("C0")
Example #10
0
def test_conversions():
    grays = [".2", ".5", ".8"]
    # "none" converts to an empty RGBA array of shape (0, 4).
    assert_array_equal(mcolors.to_rgba_array("none"), np.zeros((0, 4)))
    # A list of grey-level strings is several colours, not a single colour.
    expected_grays = np.vstack([mcolors.to_rgba(level) for level in grays])
    assert_array_equal(mcolors.to_rgba_array(grays), expected_grays)
    # An explicit alpha is applied to both RGB tuples and grey levels.
    assert mcolors.to_rgba((1, 1, 1), .5) == (1, 1, 1, .5)
    assert mcolors.to_rgba(".1", .5) == (.1, .1, .1, .5)
    # Rounding differs between Python 2 and 3, so the hex result is pinned.
    assert mcolors.to_hex((.7, .7, .7)) == "#b2b2b2"
    # A hex string with alpha survives a round trip through RGBA.
    hex_color = "#1234abcd"
    rgba = mcolors.to_rgba(hex_color)
    assert mcolors.to_hex(rgba, keep_alpha=True) == hex_color
def get_dendrogram_color_fun(Z, labels, clusters, color_palette=sns.hls_palette):
    """Build a link-colour function for a dendrogram.

    ref: https://stackoverflow.com/questions/38153829/custom-cluster-colors-of-scipy-dendrogram-in-python-link-color-func

    Args:
        Z: linkage matrix; each row is a link joining the two ids in its
            first two columns.
        labels: leaf ids in dendrogram order -- indices into the original
            clustered list, e.g. [0, 3, 1, 2].
        clusters: cluster assignment of every item in the original order;
            used as ``clusters[i] - 1`` to index the palette (presumably
            1-based labels -- confirm).
        color_palette: callable returning that many colours
            (default ``sns.hls_palette``).

    Returns:
        ``(link_color_func, palette)``: a function mapping a link id to a hex
        colour, and the palette obtained from ``color_palette``. A link gets
        its children's colour when both children share it, else grey.
    """
    mixed_color = "#808080"  # grey for links that join differently coloured children
    palette = color_palette(len(np.unique(clusters)))
    leaf_colors = {leaf: to_hex(palette[clusters[leaf] - 1]) for leaf in labels}

    # Ids greater than len(Z) refer to links created by earlier rows; rows are
    # processed in order, so those links already have a colour.
    n_links = len(Z)
    link_colors = {}
    for row, children in enumerate(Z[:, :2].astype(int)):
        left, right = (
            link_colors[child] if child > n_links else leaf_colors[child]
            for child in children
        )
        link_colors[row + 1 + n_links] = left if left == right else mixed_color

    return (lambda link_id: link_colors[link_id]), palette
 def plotIntersection(self, eq1, eq2, line_type='-', color='Blue'):
     """
     plot the intersection of two linear equations in 3d

     Each equation is ``[a1, a2, a3, b]`` describing ``a1 x + a2 y + a3 z = b``.
     For every axis and each of its two current limits, the coordinate on that
     axis is fixed to the limit and the other two are solved for; solutions
     inside the other two axes' limits are collected and handed to
     ``self.plotLine(points, color, line_type)``.
     """
     bounds = np.array([self.ax.axes.get_xlim(),
                        self.ax.axes.get_ylim(),
                        self.ax.axes.get_zlim()])
     coeffs = np.array([np.array(eq1), np.array(eq2)])
     A = coeffs[:, :-1]
     b = coeffs[:, -1]
     ptlist = []
     for i in range(3):
         # the two coordinates left free once coordinate i is fixed
         free = [k for k in range(3) if k != i]
         A2 = A[:, free]
         for j in range(2):
             b2 = b - bounds[i, j] * A[:, i]
             try:
                 pt = np.linalg.inv(A2).dot(b2)
             except np.linalg.LinAlgError:
                 # singular system: no unique point on this bounding plane
                 continue
             if (bounds[free[0]][0] <= pt[0] <= bounds[free[0]][1]
                     and bounds[free[1]][0] <= pt[1] <= bounds[free[1]][1]):
                 point = [0, 0, 0]
                 point[free[0]] = pt[0]
                 point[free[1]] = pt[1]
                 point[i] = bounds[i, j]
                 ptlist.append(point)
     self.plotLine(ptlist, color, line_type)
Example #13
0
def main():
    SIZE = 20
    PLOIDY = 2
    MUTATIONS = 2

    # Build fake data: PLOIDY copies of an all-"0" sequence, each with
    # MUTATIONS randomly drawn positions set to "1" (a position may be drawn
    # twice, so a copy can carry fewer distinct mutations).
    positions = range(SIZE)
    reference = list("0" * SIZE)
    copies = [reference[:] for _ in range(PLOIDY)]  # one copy per ploidy level
    for copy in copies:
        hits = [choice(positions) for _ in range(MUTATIONS)]
        for pos in hits:
            copy[pos] = "1"

    names = [str(n) for n in range(PLOIDY)]
    allseqs = [make_sequence(copy, name=label) for copy, label in zip(copies, names)]

    # Build graph structure
    G = Graph("Assembly graph", filename="graph")
    G.attr(rankdir="LR", fontname="Helvetica", splines="true")
    G.attr(ranksep=".2", nodesep="0.02")
    G.attr('node', shape='point')
    G.attr('edge', dir='none', penwidth='4')

    # One distinct Set2 colour per sequence copy.
    palette = [to_hex(rgb) for rgb in get_map('Set2', 'qualitative', 8).mpl_colors]
    picked = sample(palette, PLOIDY)
    for seq, col in zip(allseqs, picked):
        sequence_to_graph(G, seq, color=col)
    zip_sequences(G, allseqs)

    # Output graph
    G.view()
Example #14
0
def color2hex(color):
    """Return *color* as a hex colour string.

    Uses ``matplotlib.colors.to_hex`` when it is available; on matplotlib
    1.5, which lacks it, falls back to ``rgb2hex`` of the RGB conversion.
    """
    try:
        from matplotlib.colors import to_hex
    except ImportError:  # MPL 1.5
        from matplotlib.colors import ColorConverter, rgb2hex
        return rgb2hex(ColorConverter().to_rgb(color))
    return to_hex(color)
Example #15
0
def color_to_name(color):
    """
    Translate between a matplotlib color representation
    and our string names.

    Matplotlib's named colours are searched first (the name is passed
    through ``pretty_name``); then the basic colour table, whose values
    are already hex strings.

    :param color: Any matplotlib color representation
    :return: The string identifier we have chosen
    :raises: ValueError if the color is not known
    """
    target = to_hex(color)
    for mpl_name, mpl_value in iteritems(mpl_named_colors()):
        if target == to_hex(mpl_value):
            return pretty_name(mpl_name)
    for basic_name, basic_hex in iteritems(_BASIC_COLORS_HEX_MAPPING):
        if target == basic_hex:
            return basic_name
    raise ValueError("matplotlib color {} unknown".format(color))
def test_cn():
    # (colours for the property cycle, {"CN" spec: expected hex})
    expectations = [
        (['blue', 'r'], {'C0': '#0000ff', 'C1': '#ff0000'}),
        (['xkcd:blue', 'r'], {'C0': '#0343df', 'C1': '#ff0000'}),
        (['8e4585', 'r'], {'C0': '#8e4585'}),
    ]
    for cycle_colors, expected in expectations:
        matplotlib.rcParams['axes.prop_cycle'] = cycler('color', cycle_colors)
        for spec, hex_value in expected.items():
            assert mcolors.to_hex(spec) == hex_value

    # With the last cycle still active: '8e4585' must be recognised as a hex
    # colour before anything parses it as a float (which would yield a very
    # large number), so the red channel must not be infinite.
    assert mcolors.to_rgb("C0")[0] != np.inf
Example #17
0
 def get_dict_for_grid_style(ax):
     """Describe the major grid of axis *ax* as a dict.

     Returns ``{"gridOn": False}`` when the major grid is off or there are no
     grid lines; otherwise the first grid line's colour (as hex) and alpha
     plus ``"gridOn": True``.
     NOTE(review): relies on the private ``ax._gridOnMajor`` attribute, which
     may not exist in every matplotlib version -- confirm.
     """
     gridlines = ax.get_gridlines()
     if not (ax._gridOnMajor and len(gridlines) > 0):
         return {"gridOn": False}
     first = gridlines[0]
     return {"color": to_hex(first.get_color()),
             "alpha": first.get_alpha(),
             "gridOn": True}
 def plotPoint(self, x1, x2, x3, color='r', alpha=1.0):
     """
     plot a single point (x1, x2, x3) in 3d and record it in
     ``self.desc['objects']`` as a 'point' entry with its hex colour.

     ``alpha`` is only stored as the entry's 'transparency'; it is not passed
     to the plot call.
     """
     # do the plotting; pass the colour as a keyword (as plotLine and
     # plotLinEqn do) -- splicing it into the format string ('{}o') only works
     # for single-letter codes and fails for names like 'Blue' or hex strings.
     self.ax.plot([x1], [x2], 'o', zs=[x3], color=color)
     # save the graphics element
     hex_color = colors.to_hex(color)
     self.desc['objects'].append(
         {'type': 'point',
          'transparency': alpha,
          'color': hex_color,
          'points': [{'x': x1, 'y': x2, 'z': x3}]})
Example #19
0
 def get_dict_from_text_style(text):
     """Collect the styling of a text artist into a plain dict.

     A missing alpha (``None``) is reported as 1; the colour is a hex string.
     """
     alpha = text.get_alpha()
     return {"alpha": 1 if alpha is None else alpha,
             "textSize": text.get_size(),
             "color": to_hex(text.get_color()),
             "hAlign": text.get_horizontalalignment(),
             "vAlign": text.get_verticalalignment(),
             "rotation": text.get_rotation(),
             "zOrder": text.get_zorder()}
Example #20
0
 def __init__(self, color, parent=None):
     """Lay out a hex text field next to a colour button, both showing *color*.

     :param color: a ``QtGui.QColor``
     :param parent: parent widget for the child widgets
     """
     QtWidgets.QHBoxLayout.__init__(self)
     assert isinstance(color, QtGui.QColor)
     # Text field holding the colour as hex with alpha; editing it calls
     # update_color.
     self.lineedit = line_edit = QtWidgets.QLineEdit(
         mcolors.to_hex(color.getRgbF(), keep_alpha=True), parent)
     line_edit.editingFinished.connect(self.update_color)
     self.addWidget(line_edit)
     # Colour button; its colorChanged signal refreshes the text via update_text.
     self.colorbtn = button = ColorButton(parent)
     button.color = color
     button.colorChanged.connect(self.update_text)
     self.addWidget(button)
Example #21
0
 def get_dict_from_line(self, line, index=0):
     """Describe *line* (at position *index*) as a plain dict.

     A missing alpha (``None``) is reported as 1; marker style and error bars
     come from ``get_dict_from_marker_style`` and ``get_dict_for_errorbars``.
     """
     label = line.get_label()
     alpha = line.get_alpha()
     return {"lineIndex": index,
             "label": label,
             "alpha": 1 if alpha is None else alpha,
             "color": to_hex(line.get_color()),
             "lineWidth": line.get_linewidth(),
             "lineStyle": line.get_linestyle(),
             "markerStyle": self.get_dict_from_marker_style(line),
             "errorbars": self.get_dict_for_errorbars(line)}
Example #22
0
def named_cycle_colors():
    """
    Retrieve a named list of colors for the current color cycle
    :return: A list of colors as human-readable strings
    """
    prop_cycle = rcParams['axes.prop_cycle']
    try:
        columns = prop_cycle.by_key()
    except AttributeError:
        # Older cycler releases (< 1) lack by_key; their private _transpose
        # returns the same mapping and is frozen in those old versions.
        columns = prop_cycle._transpose()
    return [color_to_name(hex_value) for hex_value in map(to_hex, columns['color'])]
 def plotLine(self, in_ptlist, color, line_type='-', alpha=1.0):
     """Record a 3d polyline in ``self.desc['objects']`` and draw it.

     :param in_ptlist: sequence of (x, y, z) points; coordinates become floats
     :param color: any matplotlib colour; stored as hex
     :param line_type: matplotlib format string for the line
     :param alpha: stored as the entry's 'transparency' (not passed to plot)
     """
     points = [[float(coord) for coord in pt] for pt in in_ptlist]
     self.desc['objects'].append({
         'type': 'line',
         'color': colors.to_hex(color),
         'transparency': alpha,
         'linetype': line_type,
         'points': [{'x': p[0], 'y': p[1], 'z': p[2]} for p in points],
     })
     coords = np.array(points).T
     self.ax.plot(coords[0, :], coords[1, :], line_type,
                  zs=coords[2, :], color=color)
 def plotLinEqn(self, l1, color='Green', alpha=0.3):
     """
     plot the plane corresponding to the linear equation
     a1 x + a2 y + a3 z = b
     where l1 = [a1, a2, a3, b]

     The plane's points come from ``self.intersectionPlaneCube(l1)``. With
     more than two of them the polygon is triangulated, recorded in
     ``self.desc['objects']`` as a 'polygonsurface' and drawn with
     ``plot_trisurf``; otherwise nothing is recorded or drawn.
     """
     pts = self.intersectionPlaneCube(l1)
     ptlist = np.array([np.array(i) for i in pts])
     # Fewer than three points cannot form a surface; returning here also
     # avoids indexing an empty array below.
     if len(ptlist) <= 2:
         return
     x = ptlist[:,0]
     y = ptlist[:,1]
     z = ptlist[:,2]
     try:
         triang = mp.tri.Triangulation(x, y)
     except (RuntimeError, ValueError):
         # this happens where there are triangles parallel to
         # the z axis so some points in the x,y plane are
         # repeated (which is illegal for a triangulation)
         # this is a hack but it works: triangulate another
         # projection and patch in the real coordinate.
         try:
             triang = mp.tri.Triangulation(x, z)
             triang.y = y
         except (RuntimeError, ValueError):
             triang = mp.tri.Triangulation(z, y)
             triang.x = x
     # save the graphics element
     hex_color = colors.to_hex(color)
     self.desc['objects'].append(
         {'type': 'polygonsurface',
          'color': hex_color,
          'transparency': alpha,
          'points': [{'x': p[0], 'y': p[1], 'z': p[2]} for p in pts],
          # flattened vertex indices of every triangle
          'triangleIndices': [int(idx) for tri in triang.triangles
                              for idx in tri]})
     # do the plotting
     self.ax.plot_trisurf(triang,
                          z,
                          color=color,
                          alpha=alpha,
                          linewidth=0,
                          shade=False)
Example #25
0
    def do_colors(code, data_column, show, name):
        """Emit source lines for a colour column and return the colours.

        Appends ``<name> = ...`` to *code* and, when ``index_per_different``
        returns an index, also ``<name>_index = ...``.

        :return: ``(colors, expression)`` where *expression* is the source
            text that rebuilds the per-item colours (``name`` alone, or
            ``array(name)[name_index]``), and *colors* is what it evaluates to
            (``None`` when no item has a colour).
        """
        tuples = [colortuple(item) for item in data_column]
        if all(t is None for t in tuples):
            palette, index = None, None
        else:
            # Items without a colour become opaque blue.
            rgba = np.array([((0, 0, 1, 1) if t is None else t)
                             for t in tuples])
            # Hidden items (Qt.NoPen, Qt.NoBrush) get zero alpha.
            rgba[:, 3][np.array(show) == 0] = 0
            # Hex strings keep the printed code short.
            hex_colors = [to_hex(c, keep_alpha=True) for c in rgba]
            palette, index = index_per_different(hex_colors)

        code.append("{} = {}".format(name, repr(palette)))
        if index is None:
            return palette, name

        code.append("{}_index = {}".format(name, numpy_repr_int(index)))
        expression = "array({})[{}_index]".format(name, name)
        return np.array(palette)[index], expression
Example #26
0
def figure_edit(axes, parent=None):
    """
    Edit matplotlib figure options.

    Shows a form built with ``_formlayout.fedit`` containing:

    * an "Axes" tab: title, x/y limits, labels and scales, and a checkbox
      to regenerate the legend;
    * a "Curves" tab (only if non-empty) with one entry per line whose label
      is not ``'_nolegend_'``;
    * an "Images, etc." tab (only if non-empty) with one entry per image or
      collection that is labelled and has array data.

    Edits are written back to *axes* by the nested ``apply_callback``, which
    is passed to ``fedit`` as its ``apply`` callback and is also called once
    with the data ``fedit`` returns, unless that is ``None``.

    :param axes: the Axes whose options are edited
    :param parent: parent widget for the dialog
    """
    sep = (None, None)  # separator

    # Get / General
    # Cast to builtin floats as they have nicer reprs.
    xmin, xmax = map(float, axes.get_xlim())
    ymin, ymax = map(float, axes.get_ylim())
    # The 10 values unpacked from ``general`` in apply_callback correspond to
    # the entries below that are neither separators nor headings, in order.
    general = [
        ('Title', axes.get_title()),
        sep,
        (None, "<b>X-Axis</b>"),
        ('Left', xmin),
        ('Right', xmax),
        ('Label', axes.get_xlabel()),
        ('Scale', [axes.get_xscale(), 'linear', 'log', 'logit']),
        sep,
        (None, "<b>Y-Axis</b>"),
        ('Bottom', ymin),
        ('Top', ymax),
        ('Label', axes.get_ylabel()),
        ('Scale', [axes.get_yscale(), 'linear', 'log', 'logit']),
        sep,
        ('(Re-)Generate automatic legend', False),
    ]

    # Save the unit data
    # (restored in apply_callback after the limits and labels are set).
    xconverter = axes.xaxis.converter
    yconverter = axes.yaxis.converter
    xunits = axes.xaxis.get_units()
    yunits = axes.yaxis.get_units()

    # Sorting for default labels (_lineXXX, _imageXXX).
    # Such labels sort by prefix, then numerically by index ('_line2' before
    # '_line10'); any other label sorts by its own text.
    def cmp_key(label):
        match = re.match(r"(_line|_image)(\d+)", label)
        if match:
            return match.group(1), int(match.group(2))
        else:
            return label, 0

    # Get / Curves
    # Lines labelled '_nolegend_' are not editable; lines sharing a label
    # collapse to the last one.
    linedict = {}
    for line in axes.get_lines():
        label = line.get_label()
        if label == '_nolegend_':
            continue
        linedict[label] = line
    curves = []

    def prepare_data(d, init):
        """
        Prepare entry for FormLayout.

        *d* is a mapping of shorthands to style names (a single style may
        have multiple shorthands, in particular the shorthands `None`,
        `"None"`, `"none"` and `""` are synonyms); *init* is one shorthand
        of the initial style.

        This function returns a list suitable for initializing a
        FormLayout combobox, namely `[initial_name, (shorthand,
        style_name), (shorthand, style_name), ...]`.
        """
        if init not in d:
            d = {**d, init: str(init)}
        # Drop duplicate shorthands from dict (by overwriting them during
        # the dict comprehension).
        name2short = {name: short for short, name in d.items()}
        # Convert back to {shorthand: name}.
        short2name = {short: name for name, short in name2short.items()}
        # Find the kept shorthand for the style specified by init.
        canonical_init = name2short[d[init]]
        # Sort by representation and prepend the initial value.
        return ([canonical_init] + sorted(
            short2name.items(), key=lambda short_and_name: short_and_name[1]))

    curvelabels = sorted(linedict, key=cmp_key)
    for label in curvelabels:
        line = linedict[label]
        # Colours are shown as hex strings with the line's alpha folded in
        # (keep_alpha=True).
        color = mcolors.to_hex(mcolors.to_rgba(line.get_color(),
                                               line.get_alpha()),
                               keep_alpha=True)
        ec = mcolors.to_hex(mcolors.to_rgba(line.get_markeredgecolor(),
                                            line.get_alpha()),
                            keep_alpha=True)
        fc = mcolors.to_hex(mcolors.to_rgba(line.get_markerfacecolor(),
                                            line.get_alpha()),
                            keep_alpha=True)
        # The nine editable fields here must stay in the order unpacked in
        # apply_callback's "Set / Curves" loop.
        curvedata = [
            ('Label', label), sep, (None, '<b>Line</b>'),
            ('Line style', prepare_data(LINESTYLES, line.get_linestyle())),
            ('Draw style', prepare_data(DRAWSTYLES, line.get_drawstyle())),
            ('Width', line.get_linewidth()), ('Color (RGBA)', color), sep,
            (None, '<b>Marker</b>'),
            ('Style', prepare_data(MARKERS, line.get_marker())),
            ('Size', line.get_markersize()), ('Face color (RGBA)', fc),
            ('Edge color (RGBA)', ec)
        ]
        curves.append([curvedata, label, ""])
    # Is there a curve displayed?
    has_curve = bool(curves)

    # Get ScalarMappables.
    # Only labelled images/collections that carry array data are editable.
    mappabledict = {}
    for mappable in [*axes.images, *axes.collections]:
        label = mappable.get_label()
        if label == '_nolegend_' or mappable.get_array() is None:
            continue
        mappabledict[label] = mappable
    mappablelabels = sorted(mappabledict, key=cmp_key)
    mappables = []
    # (colormap, name) pairs for the Colormap combobox.
    # NOTE(review): relies on the private ``cm._cmap_registry``.
    cmaps = [(cmap, name) for name, cmap in sorted(cm._cmap_registry.items())]
    for label in mappablelabels:
        mappable = mappabledict[label]
        cmap = mappable.get_cmap()
        # An unregistered colormap is prepended so it can be selected. As
        # ``cmaps`` is rebound (not copied per mappable), it also remains in
        # the choices offered for later mappables.
        if cmap not in cm._cmap_registry.values():
            cmaps = [(cmap, cmap.name), *cmaps]
        low, high = mappable.get_clim()
        mappabledata = [
            ('Label', label),
            ('Colormap', [cmap.name] + cmaps),
            ('Min. value', low),
            ('Max. value', high),
        ]
        if hasattr(mappable, "get_interpolation"):  # Images.
            interpolations = [(name, name)
                              for name in sorted(mimage.interpolations_names)]
            mappabledata.append(
                ('Interpolation',
                 [mappable.get_interpolation(), *interpolations]))
        mappables.append([mappabledata, label, ""])
    # Is there a scalarmappable displayed?
    has_sm = bool(mappables)

    datalist = [(general, "Axes", "")]
    if curves:
        datalist.append((curves, "Curves", ""))
    if mappables:
        datalist.append((mappables, "Images, etc.", ""))

    def apply_callback(data):
        """A callback to apply changes."""
        orig_xlim = axes.get_xlim()
        orig_ylim = axes.get_ylim()

        # ``data`` holds one entry per tab, in the order the tabs were added
        # to ``datalist``; the Curves/Images tabs exist only when non-empty.
        general = data.pop(0)
        curves = data.pop(0) if has_curve else []
        mappables = data.pop(0) if has_sm else []
        if data:
            raise ValueError("Unexpected field")

        # Set / General
        (title, xmin, xmax, xlabel, xscale, ymin, ymax, ylabel, yscale,
         generate_legend) = general

        # Scales are only changed when they differ, and before the limits.
        if axes.get_xscale() != xscale:
            axes.set_xscale(xscale)
        if axes.get_yscale() != yscale:
            axes.set_yscale(yscale)

        axes.set_title(title)
        axes.set_xlim(xmin, xmax)
        axes.set_xlabel(xlabel)
        axes.set_ylim(ymin, ymax)
        axes.set_ylabel(ylabel)

        # Restore the unit data
        axes.xaxis.converter = xconverter
        axes.yaxis.converter = yconverter
        axes.xaxis.set_units(xunits)
        axes.yaxis.set_units(yunits)
        axes.xaxis._update_axisinfo()
        axes.yaxis._update_axisinfo()

        # Set / Curves
        for index, curve in enumerate(curves):
            line = linedict[curvelabels[index]]
            (label, linestyle, drawstyle, linewidth, color, marker, markersize,
             markerfacecolor, markeredgecolor) = curve
            line.set_label(label)
            line.set_linestyle(linestyle)
            line.set_drawstyle(drawstyle)
            line.set_linewidth(linewidth)
            rgba = mcolors.to_rgba(color)
            # The edited colour already carries the alpha, so clear the
            # line's separate alpha before applying the RGBA colour.
            line.set_alpha(None)
            line.set_color(rgba)
            # Marker size and colours are only applied when a marker is set.
            if marker != 'none':
                line.set_marker(marker)
                line.set_markersize(markersize)
                line.set_markerfacecolor(markerfacecolor)
                line.set_markeredgecolor(markeredgecolor)

        # Set ScalarMappables.
        # Five fields for images (with interpolation), four otherwise; these
        # are the only lengths mappabledata above can produce.
        for index, mappable_settings in enumerate(mappables):
            mappable = mappabledict[mappablelabels[index]]
            if len(mappable_settings) == 5:
                label, cmap, low, high, interpolation = mappable_settings
                mappable.set_interpolation(interpolation)
            elif len(mappable_settings) == 4:
                label, cmap, low, high = mappable_settings
            mappable.set_label(label)
            mappable.set_cmap(cm.get_cmap(cmap))
            mappable.set_clim(*sorted([low, high]))

        # re-generate legend, if checkbox is checked
        if generate_legend:
            draggable = None
            ncol = 1
            if axes.legend_ is not None:
                # Carry over the old legend's draggability and column count.
                # NOTE(review): relies on private Legend attributes
                # (_draggable, _ncol) that may differ between matplotlib
                # versions -- confirm against the version in use.
                old_legend = axes.get_legend()
                draggable = old_legend._draggable is not None
                ncol = old_legend._ncol
            new_legend = axes.legend(ncol=ncol)
            if new_legend:
                new_legend.set_draggable(draggable)

        # Redraw
        figure = axes.get_figure()
        figure.canvas.draw()
        # If the limits changed, push the new view onto the toolbar's stack.
        if not (axes.get_xlim() == orig_xlim and axes.get_ylim() == orig_ylim):
            figure.canvas.toolbar.push_current()

    data = _formlayout.fedit(
        datalist,
        title="Figure options",
        parent=parent,
        icon=QtGui.QIcon(
            str(cbook._get_data_path('images', 'qt4_editor_options.svg'))),
        apply=apply_callback)
    # NOTE(review): presumably fedit returns None when the dialog is
    # cancelled; otherwise the final values are applied once more.
    if data is not None:
        apply_callback(data)
        (0.65, 0.0, 0.0),
        (1.0, 0.0, 0.0)
    ]
}

# Colormap from the segment data in ``cdict``. (A simpler alternative would be
# cpl.LinearSegmentedColormap.from_list("", ["blue", "violet", "red"]).)
cm = cpl.LinearSegmentedColormap("", cdict)
cNorm = cpl.Normalize(vmin=0.5, vmax=3.75)
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cm)

# Track points as [longitude, latitude]; consecutive points are then paired
# into the line segments that LineCollection expects.
xy = [[lon, lat] for lon, lat in zip(df.longitude, df.latitude)]
xy = [[start, end] for start, end in zip(xy[:-1], xy[1:])]

# RGBA colour per segment, and hex colours for the cumulative-average patch.
cSegments = [scalarMap.to_rgba(value) for value in df.avgagg]
avg_agg_segments = [cpl.to_hex(scalarMap.to_rgba(value)) for value in df.agg_cumavg]

interval = 1
fig, ax = plt.subplots(figsize=(10, 8))
line = LineCollection(xy[:interval], linewidth=7, color=cSegments[:interval])
ax.add_collection(line)
# Limits are set explicitly (leaving room on the right for the patch) rather
# than autoscaled.
ax.set(xlim=(longitude_min, longitude_max + 0.005),
       ylim=(latitude_min, latitude_max))
# Small square at the right edge, vertically centred, coloured by the first
# cumulative-average value.
patch = patches.Rectangle(
    (longitude_max, (latitude_min + latitude_max) / 2),
    0.005,
    0.005,
    facecolor=avg_agg_segments[0],
)
ax.add_patch(patch)
Example #28
0
from Bio import SeqIO
from Bio.Graphics import BasicChromosome
from Bio.SeqFeature import SeqFeature, FeatureLocation
from reportlab.lib.units import cm
from matplotlib import cm as mcm
from matplotlib.colors import to_hex

# Log file handle. NOTE(review): not closed in the visible code -- confirm it
# is closed (or otherwise managed) further down the script.
logh = open(snakemake.log[0], "w")
# {chrom: length}, keeping only FASTA records longer than 100000 bases
entries = {}
for rec in SeqIO.parse(snakemake.params["genome"], "fasta"):
    if len(rec.seq) > 100000:
        entries[rec.id] = len(rec.seq)
# all available colors
# cols = list(colors.getAllNamedColors().values())[8:]
# One hex colour per kept chromosome from matplotlib's tab10 colormap.
# NOTE(review): tab10 has 10 colours; with more than 10 entries the integer
# indices past 9 presumably all map to the same out-of-range colour -- confirm.
cols = [to_hex(mcm.tab10(c)) for c in range(len(entries))]
# {chrom: {type: [features]}} where a feature is [start, end, type, color]
# NOTE(review): defaultdict(list) gives list values, not the nested
# {type: [...]} dict the comment above describes -- confirm against the code
# that fills it.
features = defaultdict(list)
# presumably a running counter and a feature-type -> colour map used further
# below (outside this view) -- confirm.
type_id = 0
ftype_to_col = {}
with open(snakemake.input["features"]) as bedh:
    bed = csv.DictReader(
        bedh, delimiter="\t", fieldnames=["chrom", "start", "end", "type"]
    )
    for feat in bed:
        start, end = int(feat["start"]), int(feat["end"])
        if start > end:
            start, end = end, start
        try:
            if end > entries[feat["chrom"]]:
                continue
Example #29
0
def plot_single(ax, ma, average_type, color, label, plot_type='lines'):
    """
    Adds a line to the plot in the given ax using the specified method

    Parameters
    ----------
    ax : matplotlib axis
        matplotlib axis
    ma : numpy array
        numpy array The data on this matrix is summarized column-wise
        (``axis=0``) according to the `average_type` argument.
    average_type : str
        name of a ``numpy.ma`` reduction: sum mean median min max std.
        An unknown name raises ``AttributeError``.
    color : str
        a valid color: either a html color name, hex
        (e.g #002233), RGB + alpha tuple or list or RGB tuple or list.
        A numpy array is converted to a hex string (alpha kept).
    label : str
        label
    plot_type: str
        type of plot. Either 'se' for standard error, 'std' for
        standard deviation, 'fill' to plot the area between the x axis and
        the value, or any other string (e.g. the default 'lines') to just
        plot the average line.

    Returns
    -------
    ax
        matplotlib axis

    Examples
    --------

    >>> import matplotlib.pyplot as plt
    >>> import os
    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(111)
    >>> matrix = np.array([[1,2,3],
    ...                    [4,5,6],
    ...                    [7,8,9]])
    >>> ax = plot_single(ax, matrix -2, 'mean', color=[0.6, 0.8, 0.9], label='fill light blue', plot_type='fill')
    >>> ax = plot_single(ax, matrix, 'mean', color='blue', label='red')
    >>> ax = plot_single(ax, matrix + 5, 'mean', color='red', label='red', plot_type='std')
    >>> ax = plot_single(ax, matrix + 10, 'mean', color='#cccccc', label='gray se', plot_type='se')
    >>> ax = plot_single(ax, matrix + 20, 'mean', color=(0.9, 0.5, 0.9), label='violet', plot_type='std')
    >>> ax = plot_single(ax, matrix + 30, 'mean', color=(0.9, 0.5, 0.9, 0.5), label='violet with alpha', plot_type='std')
    >>> leg = ax.legend()
    >>> plt.savefig("/tmp/test.pdf")
    >>> plt.close()
    >>> fig = plt.figure()
    >>> os.remove("/tmp/test.pdf")


    """
    # Look the reduction up by name on numpy.ma (np.ma.mean, np.ma.median, ...).
    summary = getattr(np.ma, average_type)(ma, axis=0)
    x = np.arange(len(summary))
    if isinstance(color, np.ndarray):
        color = pltcolors.to_hex(color, keep_alpha=True)
    # The average line is always drawn; the plot types below only add shading.
    ax.plot(x, summary, color=color, label=label, alpha=0.9)
    if plot_type == 'fill':
        ax.fill_between(x,
                        summary,
                        facecolor=color,
                        alpha=0.6,
                        edgecolor='none')

    if plot_type in ('se', 'std'):
        std = np.std(ma, axis=0)
        if plot_type == 'se':  # standard error
            std = std / np.sqrt(ma.shape[0])

        # An alpha channel has to be added to the color to fill the area
        # between the mean (or median etc.) and the std or se.
        f_color = pltcolors.to_rgba(color, alpha=0.2)

        ax.fill_between(x,
                        summary,
                        summary + std,
                        facecolor=f_color,
                        edgecolor='none')
        ax.fill_between(x,
                        summary,
                        summary - std,
                        facecolor=f_color,
                        edgecolor='none')

    # With a single bin the limits would collapse to (0, 0), which matplotlib
    # rejects with a warning; keep the autoscaled limits in that case.
    if len(x) > 1:
        ax.set_xlim(0, x[-1])

    return ax
Example #30
0
def define_color_cycler_from_map(n, colormap=None):
    """Build a ``cycler`` of ``n`` hex colours sampled evenly from a colormap.

    Samples are taken at ``i / n`` for ``i`` in ``range(n)``, so the
    colormap's upper end (1.0) is never used. ``default_color_map`` is used
    when ``colormap`` is None.
    """
    cmap = default_color_map if colormap is None else colormap
    hex_colors = []
    for idx in range(n):
        position = float(idx) / float(n)
        hex_colors.append(to_hex(cmap(position)))
    return cycler(color=hex_colors)
Example #31
0
def state_graph(adata,
                group,
                basis="umap",
                x=0,
                y=1,
                color='ntr',
                layer="X",
                highlights=None,
                labels=None,
                values=None,
                theme=None,
                cmap=None,
                color_key=None,
                color_key_cmap=None,
                background=None,
                ncols=1,
                pointsize=None,
                figsize=(6, 4),
                show_legend=True,
                use_smoothed=True,
                show_arrowed_spines=True,
                ax=None,
                sort='raw',
                frontier=False,
                save_show_or_return="show",
                save_kwargs=None,
                s_kwargs_dict=None,
                **kwargs):
    """Plot a summarized cell type (state) transition graph. This function tries to create a model that summarizes
    the possible cell type transitions based on the reconstructed vector field function.

    Parameters
    ----------
        group: `str`
            The column in adata.obs that will be used to aggregate data points for the purpose of creating a cell type
            transition model.
        %(scatters.parameters.no_aggregate|kwargs|save_kwargs)s
        save_kwargs: `dict` or `None` (default: `None`)
            A dictionary that will passed to the save_fig function. By default (None) it is treated as an empty
            dictionary and the save_fig function will use the {"path": None, "prefix": 'state_graph', "dpi": None,
            "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can
            provide a dictionary that properly modify those keys according to your needs.
        s_kwargs_dict: `dict` or `None` (default: `None`)
            The dictionary of the scatter arguments. It is copied, not modified; its "s" entry is replaced by the
            group sizes.
    Returns
    -------
        Plot the a model of cell fate transition that summarizes the possible lineage commitments between different cell
        types. With save_show_or_return="return", returns (axes_list, color_list, font_color) from `scatters`.
    """

    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib.colors import to_hex

    save_kwargs = {} if save_kwargs is None else save_kwargs
    # Copy so neither the caller's dict nor a shared default accumulates the "s" entry across calls.
    s_kwargs_dict = {} if s_kwargs_dict is None else dict(s_kwargs_dict)

    aggregate = group

    points = adata.obsm["X_" + basis][:, [x, y]]
    groups, uniq_grp = adata.obs[group], adata.obs[group].unique().to_list()
    group_median = np.zeros((len(uniq_grp), 2))
    # NOTE(review): value_counts() orders groups by frequency, not by `uniq_grp` order —
    # confirm this matches the order `scatters` expects for aggregated point sizes.
    grp_size = adata.obs[group].value_counts().values
    s_kwargs_dict.update({"s": grp_size})

    # Work on a copy: the thresholding and row normalisation below would
    # otherwise overwrite the graph stored in adata.uns on every call.
    Pl = adata.uns["Cell type annotation_graph"]["group_graph"].copy()
    # Keep only the dominant direction of each pair of transitions, then row-normalise.
    Pl[Pl - Pl.T < 0] = 0
    Pl /= Pl.sum(1)[:, None]

    for i, cur_grp in enumerate(uniq_grp):
        group_median[i, :] = np.nanmedian(
            points[np.where(groups == cur_grp)[0], :2], 0)

    # Resolve the background once so it is always defined, whether or not the caller passed one.
    if background is None:
        _background = rcParams.get("figure.facecolor")
        background = to_hex(
            _background) if type(_background) is tuple else _background

    plt.figure(facecolor=background)
    axes_list, color_list, font_color = scatters(
        adata=adata,
        basis=basis,
        x=x,
        y=y,
        color=color,
        layer=layer,
        highlights=highlights,
        labels=labels,
        values=values,
        theme=theme,
        cmap=cmap,
        color_key=color_key,
        color_key_cmap=color_key_cmap,
        background=background,
        ncols=ncols,
        pointsize=pointsize,
        figsize=figsize,
        show_legend=show_legend,
        use_smoothed=use_smoothed,
        aggregate=aggregate,
        show_arrowed_spines=show_arrowed_spines,
        ax=ax,
        sort=sort,
        save_show_or_return='return',
        frontier=frontier,
        **s_kwargs_dict,
        return_all=True,
    )

    axes = axes_list if isinstance(axes_list, list) else [axes_list]
    for cur_ax in axes:
        # A matplotlib patch can belong to only one Axes, so build a fresh set
        # of arrow patches for every Axes instead of re-using one set.
        arrows = create_edge_patches_from_markov_chain(Pl,
                                                       group_median,
                                                       tol=0.01,
                                                       node_rad=15)
        for arrow in arrows:
            cur_ax.add_patch(arrow)
        cur_ax.set_facecolor(background)

    plt.axis("off")

    if save_show_or_return == "save":
        s_kwargs = {
            "path": None,
            "prefix": 'state_graph',
            "dpi": None,
            "ext": 'pdf',
            "transparent": True,
            "close": True,
            "verbose": True
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        save_fig(**s_kwargs)
    elif save_show_or_return == "show":
        if show_legend:
            plt.subplots_adjust(right=0.85)
        plt.tight_layout()
        plt.show()
    elif save_show_or_return == "return":
        return axes_list, color_list, font_color
Example #32
0
 def update_text(self, color):
     """Show *color* (presumably a QColor) in the line edit as a hex string, alpha included."""
     rgba = color.getRgbF()
     hex_text = mcolors.to_hex(rgba, keep_alpha=True)
     self.lineedit.setText(hex_text)
Example #33
0
    "hydro-rsv": "magenta",
    "hydrogen-storage": "pink",
    "lithium-battery": "salmon",
    "waste-st": "yellowgreen",
    "oil-ocgt": "black",
    "other": "red",
    "other-res": "orange",
    "electricity-load": "slategray",
    "import": "mediumpurple",
    "storage": "plum",
    "mixed-st": "chocolate",
    "decentral_heat-gshp": "darkcyan",
    "flex-decentral_heat-gshp": "darkcyan",
    "fossil": "darkgray",
}
# Normalise every colour spec in the `color` mapping (name -> matplotlib colour) to hex.
color_dict = {
    tech_name: colors.to_hex(tech_color)
    for tech_name, tech_color in color.items()
}

path = os.path.join(os.getcwd(), "results")

renewables = [
    "hydro-ror",
    "hydro-reservoir",
    "wind-offshore",
    "wind-onshore",
    "solar-pv",
    "other-res",
    "biomass-st",
]
storages = [
    "hydrogen-storage",
    "redox-battery",
Example #34
0
def test_xkcd():
    """Plain and ``xkcd:``-prefixed names resolve to different blues."""
    expected = {
        "blue": "#0000ff",
        "xkcd:blue": "#0343df",
    }
    for spec, hex_code in expected.items():
        assert mcolors.to_hex(spec) == hex_code
 def update_text(self, color):
     """Show *color* (presumably a QColor) in the line edit as a hex string, alpha included."""
     rgba = color.getRgbF()
     hex_text = mcolors.to_hex(rgba, keep_alpha=True)
     self.lineedit.setText(hex_text)
Example #36
0
def _set_colors_for_categorical_obs(adata, value_to_plot, palette):
    """
    Sets the adata.uns[value_to_plot + '_colors'] according to the given palette

    Parameters
    ----------
    adata : annData object
    value_to_plot : name of a valid categorical observation
    palette : Palette should be either a valid `matplotlib.pyplot.colormaps()` string,
              a list or tuple of colors (in a format that can be understood by matplotlib,
              e.g. RGB, RGBA, hex; names from ``utils.additional_colors`` are also accepted),
              or a cycler object with key='color'.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if a palette color is invalid, if the palette is of an unsupported
        type, or if a cycler lacks the 'color' key.
    """
    from matplotlib.colors import to_hex
    from cycler import Cycler, cycler

    categories = adata.obs[value_to_plot].cat.categories
    # check if palette is a valid matplotlib colormap
    if isinstance(palette, str) and palette in pl.colormaps():
        # this creates a palette from a colormap. E.g. 'Accent, Dark2, tab20'
        cmap = pl.get_cmap(palette)
        colors_list = [
            to_hex(x) for x in cmap(np.linspace(0, 1, len(categories)))
        ]

    else:
        # A list or tuple of colors is converted to a cycler, so it does not
        # matter if it is shorter than the number of categories (colors repeat).
        if isinstance(palette, (list, tuple)):
            if len(palette) < len(categories):
                logg.warn(
                    "Length of palette colors is smaller than the number of "
                    "categories (palette length: {}, categories length: {}. "
                    "Some categories will have the same color.".format(
                        len(palette), len(categories)))
            # check that colors are valid
            _color_list = []
            for color in palette:
                if not is_color_like(color):
                    # check if the color is a valid R color and translate it
                    # to a valid hex color value
                    if color in utils.additional_colors:
                        color = utils.additional_colors[color]
                    else:
                        raise ValueError(
                            "The following color value of the given palette is not valid: {}"
                            .format(color))
                _color_list.append(color)

            palette = cycler(color=_color_list)
        if not isinstance(palette, Cycler):
            raise ValueError(
                "Please check that the value of 'palette' is a "
                "valid matplotlib colormap string (eg. Set2), a "
                "list of color names or a cycler with a 'color' key.")
        if 'color' not in palette.keys:
            raise ValueError("Please set the palette key 'color'.")

        # Calling a cycler yields an infinite, repeating iterator of style dicts.
        cc = palette()
        colors_list = [
            to_hex(next(cc)['color']) for _ in range(len(categories))
        ]

    adata.uns[value_to_plot + '_colors'] = colors_list
Example #37
0
def color_fader(color_1: str, color_2: str, mix: float = 0.) -> str:
    """Blend two colours linearly and return the result as a hex string.

    ``mix=0`` gives ``color_1``; ``mix=1`` gives ``color_2``.
    """
    start_rgb = np.array(mpl_colors.to_rgb(color_1))
    end_rgb = np.array(mpl_colors.to_rgb(color_2))
    blended = (1 - mix) * start_rgb + mix * end_rgb
    return mpl_colors.to_hex(blended)
Example #38
0
def test_color_names():
    """Plain, ``xkcd:`` and ``tab:`` colour names each map to their own hex value."""
    cases = [
        ("blue", "#0000ff"),
        ("xkcd:blue", "#0343df"),
        ("tab:blue", "#1f77b4"),
    ]
    for spec, expected in cases:
        assert mcolors.to_hex(spec) == expected
Example #39
0
def gmtColormap_openfile(cptf,
                         name=None,
                         method='cdict',
                         N=None,
                         ret_cmap_type='LinearSegmented'):
    """Read a GMT color map from an OPEN cpt file
    Edited by: bouziot, 2020.03

    Parameters
    ----------
    cptf : str, open file or url handle
        path to .cpt file (an object with ``readlines()``; its ``.name`` is
        used only when `name` is not given). Text and bytes lines are accepted.

    name : str, optional
        name for color map
        if not provided, the file name will be used

    method : str, suggests the method to use.
    If method = 'cdict', generates the LinearSegmentedColormap using a color dictionary (cdict), disregarding any value in N.
    If method = 'list', generates the LinearSegmentedColor using the .from_list() method, passing a list of (value, (r,g,b)) tuples obtained from the cpt file. This allows the selection of colormap resolution by the user, using the N parameter

    N : int, the number of colors in the colormap. Only useful when method='list'.

    ret_cmap_type: str, the type of matplotlib cmap object to be returned. Accepts either 'LinearSegmented', which returns a matplotlib.colors.LinearSegmentedColormap, or 'Listed', which returns a ListedColormap
    In case 'Listed' is selected, the method argument from the user is ignored and method is set to 'list' ('Linear' doesn't work with 'cdict').
    N is then passed to matplotlib.colors.ListedColormap().
    - If N is set to None: all colors of the cpt file will be returned as a list.
    - In case of a user-defined N, colors will be truncated or extended by repetition (see matplotlib.colors.ListedColormap).
    'raw' returns (positions, [(r, g, b), ...]) with components in 0-255.

    Returns
    -------
    a matplotlib colormap object (matplotlib.colors.LinearSegmentedColormap or matplotlib.colors.ListedColormap)
    For 'Listed' and 'raw', a (positions, colors) pair is returned.

    Raises
    ------
    ValueError
        if the file contains no color rows.
    TypeError
        for an unsupported method / ret_cmap_type / N combination.

    Credits
    -------
    This function originally appears in pycpt, extensive edits from bouziot, 2020.03
    Original work in: https://github.com/j08lue/pycpt
    LOG OF EDITS (2020.03):
        - Fixed bug when parsing non-split '#' lines in .cpt files
        - Fixed bug - not identifying the colorModel '#' line correctly
        - Fixed binary comparison performance (introduced in python 3)
        - Added functionality to return ListedColormaps and cmaps with custom colors (method, ret_cmap_type args)
        - Added global name handling externally (_getname() func)
    """
    methodnames = ['cdict', 'list']  # accepted method arguments
    ret_cmap_types = ['LinearSegmented', 'Listed',
                      'raw']  # accepted return matplotlib colormap types

    # generate cmap name
    if name is None:
        name = _getname(cptf.name)

    # process file
    x = []
    r = []
    g = []
    b = []
    lastls = None
    # GMT's default color model is RGB; a '# COLOR_MODEL = ...' header overrides it.
    # Without this default, files lacking the header crashed on an unbound name.
    colorModel = "RGB"
    for l in cptf.readlines():
        ls = l.split()

        # skip empty lines
        if not ls:
            continue

        # Header/comment lines start with '#', which is not always separated
        # from the following text, so test the first character of the raw line.
        if (isinstance(l, bytes) and l.decode('utf-8')[0] in ["#", b"#"]) or (
                isinstance(l, str) and l[0] in ["#", b"#"]):
            if ls[-1] in ["HSV", b"HSV"]:
                colorModel = "HSV"
            elif ls[-1] in ["RGB", b"RGB"]:
                colorModel = "RGB"
            # any other comment line is ignored
            continue

        # skip BFN info (background / foreground / NaN color lines)
        if ls[0] in ["B", b"B", "F", b"F", "N", b"N"]:
            continue

        # parse color vectors; a row reads "x0 r0 g0 b0 x1 r1 g1 b1"
        x.append(float(ls[0]))
        r.append(float(ls[1]))
        g.append(float(ls[2]))
        b.append(float(ls[3]))

        # save last row
        lastls = ls

    if lastls is None:
        raise ValueError("No color entries found in the cpt file.")

    # check if last endrow has the same color, if not, append
    if not ((float(lastls[5]) == r[-1]) and (float(lastls[6]) == g[-1]) and
            (float(lastls[7]) == b[-1])):
        x.append(float(lastls[4]))
        r.append(float(lastls[5]))
        g.append(float(lastls[6]))
        b.append(float(lastls[7]))

    x = np.array(x)
    r = np.array(r)
    g = np.array(g)
    b = np.array(b)

    if colorModel == "HSV":
        for i in range(r.shape[0]):
            # convert HSV to RGB (hue is given in degrees)
            rr, gg, bb = colorsys.hsv_to_rgb(r[i] / 360., g[i], b[i])
            r[i] = rr
            g[i] = gg
            b[i] = bb
    elif colorModel == "RGB":
        r /= 255.
        g /= 255.
        b /= 255.

    red = []
    blue = []
    green = []
    # Positions rescaled to [0, 1].
    xNorm = (x - x[0]) / (x[-1] - x[0])

    # return colormap
    if method == 'cdict' and ret_cmap_type == 'LinearSegmented':
        # generate cdict
        for i in range(len(x)):
            red.append([xNorm[i], r[i], r[i]])
            green.append([xNorm[i], g[i], g[i]])
            blue.append([xNorm[i], b[i], b[i]])
        cdict = dict(red=red, green=green, blue=blue)
        return mcolors.LinearSegmentedColormap(name=name, segmentdata=cdict)

    elif method == 'list' and ret_cmap_type == 'LinearSegmented':
        # generate list of values in the form of (value, c)
        outlist = []
        for i in range(len(x)):
            tup = (xNorm[i], (r[i], g[i], b[i]))
            outlist.append(tup)

        if N and type(N) == int:
            return mcolors.LinearSegmentedColormap.from_list(name,
                                                             outlist,
                                                             N=N)
        else:
            raise TypeError(
                "Using the method 'list' requires you to set a number of colors N."
            )

    elif ret_cmap_type == 'Listed':
        # generate list of values and return it as ListedColormap
        # returns both colors and the normalized positions (pos) where colors change, in the form of two outputs pos, colors
        pos_out = []
        colors_out = []
        for i in range(len(x)):
            pos_out.append(xNorm[i])  # list of positions
            colors_out.append(mcolors.to_hex(
                (r[i], g[i], b[i])))  # list of colors
        # return pos, color pairs
        if N and type(N) == int and N <= len(colors_out):
            pos_out = pos_out[:N]  #truncate positions to N
            return pos_out, mcolors.ListedColormap(colors_out, name=name, N=N)
        elif N is None:
            return pos_out, mcolors.ListedColormap(colors_out, name=name)
        else:
            raise TypeError(
                "N has to be a number of colors that is less than the actual colors found in the .cpt file ("
                + str(len(colors_out)) + " colors found).")
    elif ret_cmap_type == 'raw':
        pos_out = []
        colors_out = []
        for i in range(len(x)):
            pos_out.append(xNorm[i])  # list of positions
            colors_out.append((r[i] * 255, g[i] * 255, b[i] * 255))
        return pos_out, colors_out
    else:
        raise TypeError("method has to be one of the arguments: " +
                        str(methodnames) +
                        " and ret_cmap_type has to be one of the arguments: " +
                        str(ret_cmap_types))
def fert_rate_color(x):
    """Map a mean fertility rate to a hex colour; missing values are drawn white."""
    if not pd.isnull(x):
        return colors.to_hex(cmap_meanfertrate(x))  # color
    return '#FFFFFF'  # white
def sensitivity(adata,
                regulators=None,
                effectors=None,
                basis="umap",
                skey='sensitivity',
                s_basis="pca",
                x=0,
                y=1,
                layer='M_s',
                highlights=None,
                cmap='bwr',
                background=None,
                pointsize=None,
                figsize=(6, 4),
                show_legend=True,
                frontier=True,
                sym_c=True,
                sort='abs',
                show_arrowed_spines=False,
                stacked_fraction=False,
                save_show_or_return="show",
                save_kwargs=None,
                **kwargs):
    """\
    Scatter plot of Sensitivity value across cells.

    Parameters
    ----------
        adata: :class:`~anndata.AnnData`
            an AnnData object with Jacobian matrix estimated.
        regulators: `list` or `None` (default: `None`)
            The list of genes that will be used as regulators for plotting the Jacobian heatmap, only limited to genes
            that have already performed Jacobian analysis.
        effectors: `List` or `None` (default: `None`)
            The list of genes that will be used as targets for plotting the Jacobian heatmap, only limited to genes
            that have already performed Jacobian analysis.
        basis: `str` (default: `umap`)
            The reduced dimension basis.
        skey: `str` (default: `sensitivity`)
            The key to the sensitivity dictionary in .uns.
        s_basis: `str` (default: `pca`)
            The reduced dimension space that will be used to calculate the jacobian matrix.
        x: `int` (default: `0`)
            The column index of the low dimensional embedding for the x-axis.
        y: `int` (default: `1`)
            The column index of the low dimensional embedding for the y-axis.
        layer: `str` (default: `M_s`)
            The layer used when x / y are gene names ('X' means the main matrix).
        highlights: `list` (default: None)
            Which color group will be highlighted. if highlights is a list of lists - each list is relate to each color element.
        cmap: string (optional, default 'bwr')
            The name of a matplotlib colormap to use for coloring
            or shading points. If no labels or values are passed
            this will be used for shading points according to
            density (largely only of relevance for very large
            datasets). If values are passed this will be used for
            shading according the value. Note that if theme
            is passed then this value will be overridden by the
            corresponding option of the theme.
        background: string or None (optional, default 'None`)
            The color of the background. Usually this will be either
            'white' or 'black', but any color name will work. Ideally
            one wants to match this appropriately to the colors being
            used for points etc. This is one of the things that themes
            handle for you. Note that if theme
            is passed then this value will be overridden by the
            corresponding option of the theme. When None, matplotlib's
            `figure.facecolor` rcParam is used.
        pointsize: `float` or `None` (default: `None`)
            Multiplier applied to the automatically chosen point size.
        figsize: `None` or `[float, float]` (default: (6, 4))
                The width and height of each panel in the figure. None uses 3 x 3 panels.
        show_legend: bool (optional, default True)
            Whether to display a legend of the labels
        frontier: `bool` (default: `True`)
            Whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to show area
            of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib tips & tricks
            cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from scEU-seq paper:
            https://science.sciencemag.org/content/367/6482/1151.
        sym_c: `bool` (default: `True`)
            Whether do you want to make the limits of continuous color to be symmetric, normally this should be used for
            plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative values.
        sort: `str` (optional, default `abs`)
            The method to reorder data so that high values points will be on top of background points. Can be one of
            {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values.
        show_arrowed_spines: bool (optional, default False)
            Whether to show a pair of arrowed spines representing the basis of the scatter is currently using.
        stacked_fraction: bool (default: False)
            If True the jacobian will be represented as a stacked fraction in the title, otherwise a linear fraction
            style is used.
        save_show_or_return: `str` {'save', 'show', 'return'} (default: `show`)
            Whether to save, show or return the figure.
        save_kwargs: `dict` or `None` (default: `None`)
            A dictionary that will passed to the save_fig function. By default (None) it is treated as an empty
            dictionary and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None,
            "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can
            provide a dictionary that properly modify those keys according to your needs.
        kwargs:
            Additional arguments passed to plt._matplotlib_points.

    Returns
    -------
    Nothing but plots the n_source x n_targets scatter plots of low dimensional embedding of the adata object, each
    corresponds to one element in the Jacobian matrix for all sampled cells. With save_show_or_return="return",
    the GridSpec is returned.

    Raises
    ------
    ValueError
        if x and y are not of a supported, matching kind.

    Examples
    --------
    >>> import dynamo as dyn
    >>> adata = dyn.sample_data.hgForebrainGlutamatergic()
    >>> adata = dyn.pp.recipe_monocle(adata)
    >>> dyn.tl.dynamics(adata)
    >>> dyn.vf.VectorField(adata, basis='pca')
    >>> valid_gene_list = adata[:, adata.var.use_for_transition].var.index[:2]
    >>> dyn.vf.sensitivity(adata, regulators=valid_gene_list[0], effectors=valid_gene_list[1])
    >>> dyn.pl.sensitivity(adata)
    """

    regulators = list(np.unique(regulators)) if regulators is not None else None
    effectors = list(np.unique(effectors)) if effectors is not None else None
    save_kwargs = {} if save_kwargs is None else save_kwargs

    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib.colors import to_hex

    if background is None:
        _background = rcParams.get("figure.facecolor")
        _background = to_hex(
            _background) if type(_background) is tuple else _background
    else:
        _background = background

    Sensitivity_ = skey if s_basis is None else skey + "_" + s_basis
    sens_info = adata.uns[Sensitivity_]
    Der = sens_info.get(skey.split("_")[-1])
    cell_indx = sens_info.get('cell_idx')
    sensitivity_gene = sens_info.get(skey.split("_")[-1] + '_gene')
    regulators_ = sens_info.get('regulators')
    effectors_ = sens_info.get('effectors')

    adata_ = adata[cell_indx, :]

    # test the simulation data here
    if (regulators_ is None or effectors_ is None):
        if Der.shape[0] != adata_.n_vars:
            source_genes = [
                s_basis + '_' + str(i) for i in range(Der.shape[0])
            ]
            target_genes = [
                s_basis + '_' + str(i) for i in range(Der.shape[1])
            ]
        else:
            source_genes, target_genes = adata_.var_names, adata_.var_names
    else:
        Der, source_genes, target_genes = intersect_sources_targets(
            regulators, regulators_, effectors, effectors_,
            Der if sensitivity_gene is None else sensitivity_gene)

    ## integrate this with the code in scatter ##

    if type(x) is int and type(y) is int:
        prefix = 'X_'
        cur_pd = pd.DataFrame({
            basis + "_" + str(x):
            adata_.obsm[prefix + basis][:, x],
            basis + "_" + str(y):
            adata_.obsm[prefix + basis][:, y],
        })
    elif is_gene_name(adata_, x) and is_gene_name(adata_, y):
        cur_pd = pd.DataFrame({
            x:
            adata_.obs_vector(k=x, layer=None)
            if layer == 'X' else adata_.obs_vector(k=x, layer=layer),
            y:
            adata_.obs_vector(k=y, layer=None)
            if layer == 'X' else adata_.obs_vector(k=y, layer=layer),
        })
        cur_pd.columns = [
            x + " (" + layer + ")",
            y + " (" + layer + ")",
        ]
    elif is_cell_anno_column(adata_, x) and is_cell_anno_column(adata_, y):
        cur_pd = pd.DataFrame({
            x: adata_.obs_vector(x),
            y: adata_.obs_vector(y),
        })
        cur_pd.columns = [x, y]
    elif is_cell_anno_column(adata_, x) and is_gene_name(adata_, y):
        cur_pd = pd.DataFrame({
            x:
            adata_.obs_vector(x),
            y:
            adata_.obs_vector(k=y, layer=None)
            if layer == 'X' else adata_.obs_vector(k=y, layer=layer),
        })
        cur_pd.columns = [x, y + " (" + layer + ")"]
    elif is_gene_name(adata_, x) and is_cell_anno_column(adata_, y):
        cur_pd = pd.DataFrame({
            x:
            adata_.obs_vector(k=x, layer=None)
            if layer == 'X' else adata_.obs_vector(k=x, layer=layer),
            y:
            adata_.obs_vector(y)
        })
        cur_pd.columns = [x + " (" + layer + ")", y]
    elif is_layer_keys(adata_, x) and is_layer_keys(adata_, y):
        x_, y_ = adata_[:, basis].layers[x], adata_[:, basis].layers[y]
        cur_pd = pd.DataFrame({x: flatten(x_), y: flatten(y_)})
        cur_pd.columns = [x, y]
    elif type(x) in [anndata._core.views.ArrayView, np.ndarray] and \
            type(y) in [anndata._core.views.ArrayView, np.ndarray]:
        cur_pd = pd.DataFrame({'x': flatten(x), 'y': flatten(y)})
        cur_pd.columns = ['x', 'y']
    else:
        # Previously fell through and crashed later with an unbound `cur_pd`.
        raise ValueError(
            "Could not interpret x={!r} and y={!r}: both must be embedding column indices, "
            "gene names, cell annotation columns, layer keys or arrays.".format(x, y))

    point_size = (500.0 /
                  np.sqrt(adata_.shape[0]) if pointsize is None else 500.0 /
                  np.sqrt(adata_.shape[0]) * pointsize)
    point_size = 4 * point_size

    scatter_kwargs = dict(
        alpha=0.2,
        s=point_size,
        edgecolor=None,
        linewidth=0,
    )  # (0, 0, 0, 1)
    scatter_kwargs.update(kwargs)

    # Per-panel size; figsize=None falls back to 3 x 3 panels. The panel size is
    # needed again below, where figsize[0] used to crash for figsize=None.
    panel_width, panel_height = (3, 3) if figsize is None else figsize

    nrow, ncol = len(source_genes), len(target_genes)
    plt.figure(None, (panel_width * ncol, panel_height * nrow),
               facecolor=_background)  # , dpi=160

    gs = plt.GridSpec(nrow, ncol, wspace=0.12)

    for i, source in enumerate(source_genes):
        for j, target in enumerate(target_genes):
            ax = plt.subplot(gs[i * ncol + j])
            S = Der[j, i, :]  # dim 0: target; dim 1: source
            cur_pd["sensitivity"] = S

            # Symmetric colour limits around zero for this panel.
            v_max = np.max(np.abs(S))
            scatter_kwargs.update({"vmin": -v_max, "vmax": v_max})
            ax, color = _matplotlib_points(cur_pd.iloc[:, [0, 1]].values,
                                           ax=ax,
                                           labels=None,
                                           values=S,
                                           highlights=highlights,
                                           cmap=cmap,
                                           color_key=None,
                                           color_key_cmap=None,
                                           background=_background,
                                           width=panel_width,
                                           height=panel_height,
                                           show_legend=show_legend,
                                           frontier=frontier,
                                           sort=sort,
                                           sym_c=sym_c,
                                           **scatter_kwargs)
            if stacked_fraction:
                ax.set_title(r'$\frac{d x_{%s}}{d x_{%s}}$' % (target, source))
            else:
                ax.set_title(r'$d x_{%s} / d x_{%s}$' % (target, source))
            if i + j == 0 and show_arrowed_spines:
                # use the resolved background (the raw argument may be None)
                arrowed_spines(ax, basis, _background)
            else:
                despline_all(ax)
                deaxis_all(ax)

    if save_show_or_return == "save":
        s_kwargs = {
            "path": None,
            "prefix": skey,
            "dpi": None,
            "ext": 'pdf',
            "transparent": True,
            "close": True,
            "verbose": True
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        save_fig(**s_kwargs)
    elif save_show_or_return == "show":
        plt.tight_layout()
        plt.show()
    elif save_show_or_return == "return":
        return gs
Example #42
0
def convert_rgb_to_hex(rgb_col):
    """Return *rgb_col* as a ``0x``-prefixed hexadecimal string.

    The color is first normalised with ``to_hex`` (alpha dropped), then the
    six hex digits are reinterpreted as one integer and formatted with
    :func:`hex`, so leading zeros are not kept (pure blue gives ``'0xff'``).
    """
    css_hex = to_hex(rgb_col, keep_alpha=False)  # '#rrggbb'
    packed_rgb = int(css_hex[1:], 16)
    return hex(packed_rgb)
Example #43
0
    dt = 0.001
    t0 = 2019.8
    tmax = 2021.5
    t, populations = trajectory(initial_population,
                                t0,
                                tmax,
                                dt,
                                params,
                                resampling_interval=0,
                                turnover=0)

    from matplotlib.cm import plasma
    from matplotlib.colors import to_hex
    colors = [
        'C0',
        to_hex(plasma(0.1)),
        to_hex(plasma(0.5)),
        to_hex(plasma(0.9))
    ]

    fs = 16
    plt.figure()
    plt.plot(t,
             populations[:, 0, 2] * params[0, 0],
             lw=3,
             label='Hubei',
             ls='--',
             c=colors[0])

    for pi in range(1, len(params)):
        plt.plot(t,
Example #44
0
 def _get_color(self):
     """Return the colormap color of the current value as a hex string.

     ``self._value`` is divided by ``self._max_value``, clamped to the range
     ``[0, 1 - 1e-9]`` and looked up in ``self._cmap``. Returns ``None``
     when the value is null.
     """
     if not isnull(self._value):
         fraction = self._value / self._max_value
         # Keep the colormap coordinate inside [0, 1 - 1e-9].
         fraction = max(0, min(1 - 1e-9, fraction))
         return to_hex(self._cmap(fraction))
     return None
Example #45
0
def plot_loo_pit(
    ax,
    figsize,
    ecdf,
    loo_pit,
    loo_pit_ecdf,
    unif_ecdf,
    p975,
    p025,
    fill_kwargs,
    ecdf_fill,
    use_hdi,
    x_vals,
    hdi_kwargs,
    hdi_odds,
    n_unif,
    unif,
    plot_unif_kwargs,
    loo_pit_kde,
    legend,  # pylint: disable=unused-argument
    y_hat,
    y,
    color,
    textsize,
    credible_interval,
    plot_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh loo pit plot.

    With ``ecdf=True`` the curve ``loo_pit - loo_pit_ecdf`` is drawn against
    ``loo_pit`` together with the band between ``p025 - unif_ecdf`` and
    ``p975 - unif_ecdf`` (filled when ``ecdf_fill`` is set, otherwise drawn
    as two boundary lines). Otherwise ``loo_pit_kde`` is drawn against
    ``x_vals`` over either an HDI box (``use_hdi``) or the KDEs of the
    ``n_unif`` rows of ``unif``.

    ``plot_kwargs``, ``plot_unif_kwargs``, ``fill_kwargs``, ``hdi_kwargs`` and
    ``backend_kwargs`` are copied before defaults are filled in, so the
    caller's dictionaries are never modified.

    Returns
    -------
    ax
        The figure drawn on: ``ax`` when given, otherwise a new
        ``bkp.figure``.
    """
    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi"),),
        **({} if backend_kwargs is None else backend_kwargs),
    }
    dpi = backend_kwargs.pop("dpi")

    (figsize, *_, linewidth, _) = _scale_fig_size(figsize, textsize, 1, 1)

    # Work on copies so setdefault()/pop() below never leak into caller dicts.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    plot_kwargs.setdefault("color", to_hex(color))
    plot_kwargs.setdefault("linewidth", linewidth * 1.4)
    if isinstance(y, str):
        label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y)
    elif isinstance(y, DataArray) and y.name is not None:
        label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y.name)
    elif isinstance(y_hat, str):
        label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y_hat)
    elif isinstance(y_hat, DataArray) and y_hat.name is not None:
        label = ("{} LOO-PIT ECDF" if ecdf else "{} LOO-PIT").format(y_hat.name)
    else:
        label = "LOO-PIT ECDF" if ecdf else "LOO-PIT"

    plot_kwargs.setdefault("legend_label", label)

    plot_unif_kwargs = {} if plot_unif_kwargs is None else dict(plot_unif_kwargs)
    # Lighter variant of the main color: halve the saturation and move the
    # value halfway towards 1 (in HSV space).
    light_color = rgb_to_hsv(to_rgb(plot_kwargs.get("color")))
    light_color[1] /= 2  # pylint: disable=unsupported-assignment-operation
    light_color[2] += (1 - light_color[2]) / 2  # pylint: disable=unsupported-assignment-operation
    light_hex = to_hex(hsv_to_rgb(light_color))
    plot_unif_kwargs.setdefault("color", light_hex)
    plot_unif_kwargs.setdefault("alpha", 0.5)
    plot_unif_kwargs.setdefault("linewidth", 0.6 * linewidth)

    if ecdf:
        n_data_points = loo_pit.size
        # Fewer than 100 points: centered steps; otherwise plain lines.
        default_drawstyle = "steps-mid" if n_data_points < 100 else "default"
        plot_kwargs.setdefault("drawstyle", default_drawstyle)
        plot_unif_kwargs.setdefault("drawstyle", default_drawstyle)

        if ecdf_fill:
            fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)
            fill_kwargs.setdefault("color", light_hex)
            fill_kwargs.setdefault("alpha", 0.5)
            fill_kwargs.setdefault(
                "step", "mid" if plot_kwargs["drawstyle"] == "steps-mid" else None
            )
            fill_kwargs.setdefault(
                "legend_label", "{:.3g}% credible interval".format(credible_interval)
            )
    elif use_hdi:
        hdi_kwargs = {} if hdi_kwargs is None else dict(hdi_kwargs)
        hdi_kwargs.setdefault("color", light_hex)
        hdi_kwargs.setdefault("alpha", 0.35)

    if ax is None:
        backend_kwargs.setdefault("width", int(figsize[0] * dpi))
        backend_kwargs.setdefault("height", int(figsize[1] * dpi))
        ax = bkp.figure(x_range=(0, 1), **backend_kwargs)

    if ecdf:
        ecdf_x = np.hstack((0, loo_pit, 1))
        ecdf_y = np.hstack((0, loo_pit - loo_pit_ecdf, 0))
        if plot_kwargs.get("drawstyle") == "steps-mid":
            ax.step(
                ecdf_x,
                ecdf_y,
                line_color=plot_kwargs.get("color", "black"),
                line_alpha=plot_kwargs.get("alpha", 1.0),
                line_width=plot_kwargs.get("linewidth", 3.0),
                mode="center",
            )
        else:
            ax.line(
                ecdf_x,
                ecdf_y,
                line_color=plot_kwargs.get("color", "black"),
                line_alpha=plot_kwargs.get("alpha", 1.0),
                line_width=plot_kwargs.get("linewidth", 3.0),
            )

        if ecdf_fill:
            # A stepped band outline is not implemented, so the credible band
            # is always drawn as a plain polygon whatever the draw style.
            ax.patch(
                np.concatenate((unif_ecdf, unif_ecdf[::-1])),
                np.concatenate((p975 - unif_ecdf, (p025 - unif_ecdf)[::-1])),
                fill_color=fill_kwargs.get("color"),
                fill_alpha=fill_kwargs.get("alpha", 1.0),
            )
        else:
            # Both band boundaries use the uniform-reference style; the draw
            # style for them was stored in plot_unif_kwargs above.
            unif_line_style = dict(
                line_color=plot_unif_kwargs.get("color", "black"),
                line_alpha=plot_unif_kwargs.get("alpha", 1.0),
                line_width=plot_unif_kwargs.get("linewidth", 1.0),
            )
            if plot_unif_kwargs.get("drawstyle") == "steps-mid":
                ax.step(unif_ecdf, p975 - unif_ecdf, mode="center", **unif_line_style)
                ax.step(unif_ecdf, p025 - unif_ecdf, mode="center", **unif_line_style)
            else:
                ax.line(unif_ecdf, p975 - unif_ecdf, **unif_line_style)
                ax.line(unif_ecdf, p025 - unif_ecdf, **unif_line_style)
    else:
        if use_hdi:
            # Map alpha/color onto Bokeh's fill_* names; pop them so they are
            # not forwarded a second time through **hdi_kwargs.
            hdi_fill_alpha = hdi_kwargs.pop("alpha")
            hdi_fill_color = hdi_kwargs.pop("color")
            ax.add_layout(
                BoxAnnotation(
                    bottom=hdi_odds[1],
                    top=hdi_odds[0],
                    fill_alpha=hdi_fill_alpha,
                    fill_color=hdi_fill_color,
                    **hdi_kwargs
                )
            )
        else:
            for idx in range(n_unif):
                x_s, unif_density = _kde(unif[idx, :])
                ax.line(
                    x_s,
                    unif_density,
                    line_color=plot_unif_kwargs.get("color", "black"),
                    line_alpha=plot_unif_kwargs.get("alpha", 0.1),
                    line_width=plot_unif_kwargs.get("linewidth", 1.0),
                )
        ax.line(
            x_vals,
            loo_pit_kde,
            line_color=plot_kwargs.get("color", "black"),
            line_alpha=plot_kwargs.get("alpha", 1.0),
            line_width=plot_kwargs.get("linewidth", 3.0),
        )

    show_layout(ax, show)

    return ax
Example #46
0
    def plot_pca(self,
                 plot_filename=None,
                 PCs=None,
                 plot_title='',
                 image_format=None,
                 log1p=False,
                 plotWidth=5,
                 plotHeight=10,
                 cols=None,
                 marks=None):
        """
        Plot the PCA of a matrix

        Returns the matrix of plotted values.

        Rows of ``self.matrix`` are filtered (zero-variance rows are dropped
        when ``self.transpose`` is set; only the ``self.ntop`` most variable
        rows are kept when ``self.ntop > 0``). The result is then optionally
        log2-transformed (``self.log2``) and row-centred (``self.rowCenter``
        without ``self.transpose``). Finally it is column-centred and scaled
        before an SVD. ``self.matrix`` itself is not modified.

        :param plot_filename: output file; nothing is plotted when ``None``.
        :param PCs: 1-based indices of the two components to plot
                    (default ``[1, 2]``).
        :param plot_title: title of the PCA panel ('PCA' when empty).
        :param image_format: forwarded to ``savefig``; ``'plotly'`` delegates
                             to ``self.plotly_pca`` instead.
        :param log1p: currently unused.
        :param plotWidth: figure width.
        :param plotHeight: figure height.
        :param cols: optional colors, cycled over the samples.
        :param marks: optional markers, cycled over the samples.
        :returns: ``(Wt, eigenvalues)``. ``Wt`` is ``Vh`` from the SVD, or the
                  projected coordinates when ``self.transpose`` is set. It has
                  one row per component, indexed per sample in the plot.
        """
        if PCs is None:
            PCs = [1, 2]

        # Filter
        m = self.matrix
        rvs = m.var(axis=1)
        if self.transpose:
            # drop zero-variance rows
            nonzero_rows = np.nonzero(rvs)[0]
            m = m[nonzero_rows, :]
            rvs = rvs[nonzero_rows]
        if self.ntop > 0 and m.shape[0] > self.ntop:
            # keep only the ntop most variable rows
            m = m[np.argpartition(rvs, -self.ntop)[-self.ntop:], :]

        # log2 (if requested). Applied to the filtered data that is analysed
        # below, out-of-place so self.matrix is left untouched.
        if self.log2:
            m = np.log2(m + 0.01)

        # Row center / transpose (out-of-place for the same reason)
        if self.rowCenter and not self.transpose:
            m = m - m.mean(axis=1)[:, None]
        if self.transpose:
            m = m.T

        # Center and scale
        m2 = (m - np.mean(m, axis=0))
        m2 /= np.std(m2, axis=0, ddof=1)  # Use the unbiased std. dev.

        # SVD
        _, s, Vh = np.linalg.svd(
            m2, full_matrices=False,
            compute_uv=True)  # Is full_matrices ever needed?

        # % variance, eigenvalues
        eigenvalues = s**2
        variance = eigenvalues / float(np.max([1, m2.shape[1] - 1]))
        pvar = variance / variance.sum()

        # Weights/projections
        Wt = Vh
        if self.transpose:
            # Use the projected coordinates for the transposed matrix
            Wt = np.dot(m2, Vh.T).T

        if plot_filename is not None:
            n = n_bars = len(self.labels)
            if eigenvalues.size < n:
                n_bars = eigenvalues.size
            markers = itertools.cycle(
                matplotlib.markers.MarkerStyle.filled_markers)
            if cols is not None:
                colors = itertools.cycle(cols)
            else:
                colors = itertools.cycle(
                    plt.cm.gist_rainbow(np.linspace(0, 1, n)))

            if marks is not None:
                markers = itertools.cycle(marks)

            if image_format == 'plotly':
                self.plotly_pca(plot_filename, Wt, pvar, PCs, eigenvalues,
                                cols, plot_title)
            else:
                # Create the matplotlib figure only on the path that saves and
                # closes it, so no figure is left open otherwise.
                fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(plotWidth, plotHeight))

                ax1.axhline(y=0, color="black", linestyle="dotted", zorder=1)
                ax1.axvline(x=0, color="black", linestyle="dotted", zorder=2)
                for i in range(n):
                    color = next(colors)
                    marker = next(markers)
                    if isinstance(color, np.ndarray):
                        color = pltcolors.to_hex(color, keep_alpha=True)
                    ax1.scatter(Wt[PCs[0] - 1, i],
                                Wt[PCs[1] - 1, i],
                                marker=marker,
                                color=color,
                                s=150,
                                label=self.labels[i],
                                zorder=i + 3)
                if plot_title == '':
                    ax1.set_title('PCA')
                else:
                    ax1.set_title(plot_title)
                ax1.set_xlabel('PC{} ({:4.1f}% of var. explained)'.format(
                    PCs[0], 100.0 * pvar[PCs[0] - 1]))
                ax1.set_ylabel('PC{} ({:4.1f}% of var. explained)'.format(
                    PCs[1], 100.0 * pvar[PCs[1] - 1]))
                lgd = ax1.legend(scatterpoints=1,
                                 loc='center left',
                                 borderaxespad=0.5,
                                 bbox_to_anchor=(1, 0.5),
                                 prop={'size': 12},
                                 markerscale=0.9)

                # Scree plot
                ind = np.arange(n_bars)  # the x locations for the groups
                width = 0.35  # the width of the bars

                if mpl.__version__ >= "2.0.0":
                    ax2.bar(2 * width + ind, eigenvalues[:n_bars], width * 2)
                else:
                    ax2.bar(width + ind, eigenvalues[:n_bars], width * 2)
                ax2.set_ylabel('Eigenvalue')
                ax2.set_xlabel('Principal Component')
                ax2.set_title('Scree plot')
                ax2.set_xticks(ind + width * 2)
                ax2.set_xticklabels(ind + 1)

                # Cumulative explained variance; pvar has exactly
                # eigenvalues.size entries, so slicing to n_bars matches ind.
                ax3 = ax2.twinx()
                ax3.axhline(y=1, color="black", linestyle="dotted")
                ax3.plot(width * 2 + ind, pvar.cumsum()[:n_bars], "r-")
                ax3.plot(width * 2 + ind,
                         pvar.cumsum()[:n_bars],
                         "wo",
                         markeredgecolor="black")
                ax3.set_ylim([0, 1.05])
                ax3.set_ylabel('Cumulative variability')

                plt.subplots_adjust(top=3.85)
                plt.tight_layout()
                plt.savefig(plot_filename,
                            format=image_format,
                            bbox_extra_artists=(lgd, ),
                            bbox_inches='tight')
                plt.close(fig)

        return Wt, eigenvalues
Example #47
0
# Derive a 'scales' column from resilience: divide by the maximum, multiply
# by 3, add 0.1 and round to one decimal.
LPIR_RESILIENT["scales"] = round(
    ((LPIR_RESILIENT.resilience / max(LPIR_RESILIENT.resilience)) * 3) + 0.1,
    1)
# Round the raw resilience values to one decimal.
LPIR_RESILIENT["resilience"] = round(LPIR_RESILIENT.resilience, 1)

# add colors based on the 'scales' column
vmin = min(LPIR_RESILIENT.scales)
vmax = max(LPIR_RESILIENT.scales)
# Map each 'scales' value to a hex color through a normalised colormap
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
mapper = plt.cm.ScalarMappable(
    norm=norm, cmap=plt.cm.YlGnBu
)  # YlGnBu colormap; other candidates: coolwarm, bwr, cool, gist_yarg (grey to black), YlOrRd
LPIR_RESILIENT['color'] = LPIR_RESILIENT['scales'].apply(
    lambda x: mcolors.to_hex(mapper.to_rgba(x)))

# add colors to map
# NOTE(review): plot_graph_folium_FK is defined elsewhere; presumably it draws
# the LPIR_RESILIENT edges on a folium map using the 'color' column — confirm.
my_map = plot_graph_folium_FK(LPIR_RESILIENT,
                              graph_map=None,
                              popup_attribute=None,
                              zoom=15,
                              fit_bounds=True,
                              edge_width=3.5,
                              edge_opacity=0.5)
# 8-digit hex colors with alpha 00, presumably meant as invisible fill/stroke
style = {'fillColor': '#00000000', 'color': '#00000000'}
# add 'u' and 'v' as highligths for each edge (in blue)
folium.GeoJson(
    # data to plot
    LPIR_RESILIENT[['u', 'v', 'scales', 'resilience', 'length',
                    'geometry']].to_json(),
Example #48
0
# Column data for the scatter of target sky positions (RA/Dec in degrees).
source1 = dict(x=coords_filtered.ra.deg,
               y=coords_filtered.dec.deg,
               name=target_names,
               B=np.ma.asarray(Bs),
               V=np.ma.asarray(Vs),
               teff=np.ma.asarray(teffs),
               index=target_names,
               # Marker radius: log of 100x the mean S index of each entry.
               # NOTE(review): negative when the mean is below 0.01 — confirm
               # inputs stay above that.
               s_size=[np.log(100 * s.mean()) for s in sindices],
               # One ", "-joined list of "value±error" strings per entry.
               S=[
                   ', '.join([
                       "{0:.2f}±{1:.2f}".format(s, e)
                       for s, e in zip(sinds, errors)
                   ]) for sinds, errors in zip(sindices, errs)
               ],
               # Fill color: Spectral_r colormap evaluated at (B-V - 0.2) / 2.
               BminusV=list(
                   map(lambda x: to_hex(plt.cm.Spectral_r((x - 0.2) / 2)),
                       bminusv)))

# Scatter of positions; circle radius and fill color come from source1 columns
# (presumably a Bokeh figure, given TOOLS/TOOLTIPS and HoverTool — confirm).
p1 = figure(tools=TOOLS, tooltips=TOOLTIPS)
p1.scatter(source=source1,
           x='x',
           y='y',
           radius='s_size',
           fill_color='BminusV',
           line_color=None)

p1.yaxis.axis_label = "Declination [deg]"
p1.xaxis.axis_label = "Right Ascension [deg]"

p1.add_tools(
    HoverTool(
def plot_subbranch(target_color, cluster_i, tree, loading, cluster_sizes, title=None,
                   size=2.3, dpi=300, plot_loc=None):
    """Plot one colored sub-branch of a dendrogram with its loading heatmap.

    The run of links in ``tree['color_list']`` whose color equals
    ``to_hex(target_color)`` is located and drawn with ``plot_tree``. The
    matching block of columns of ``loading.T`` is drawn underneath as a
    heatmap; the block is cluster ``cluster_i`` in the order given by
    ``cluster_sizes``. The figure also gets a colorbar, numbered column
    labels, and a bar plot of the mean absolute loading per row.

    Parameters
    ----------
    target_color : color of the sub-branch (anything ``to_hex`` accepts).
    cluster_i : index of the cluster within ``cluster_sizes``.
    tree : dict with a ``'color_list'`` entry, also passed to ``plot_tree``
        (presumably a scipy dendrogram dict — confirm).
    loading : DataFrame whose transpose is sliced column-wise per cluster.
    cluster_sizes : size of each cluster, in plotting order.
    title : optional dendrogram title.
    size : base figure size; most font and line sizes scale with it.
    dpi : resolution used when saving.
    plot_loc : if given, the figure is saved there and closed.

    Returns
    -------
    The figure when ``plot_loc`` is None, otherwise None. Also returns None
    when the target color is not found, or when the run length does not match
    ``cluster_sizes[cluster_i]``.
    """
    sns.set_style('white')
    colormap = sns.diverging_palette(220,15,n=100,as_cmap=True)
    # get variables in subbranch based on coloring
    color_list = tree['color_list']
    if len(color_list) == 0:
        return None
    target_hex = to_hex(target_color)
    curr_color = color_list[0]
    start = 0
    end = None
    for i, color in enumerate(color_list):
        if color != curr_color:
            end = i
            if curr_color == target_hex:
                break
            # grey ("#808080") links do not start a new colored run
            if color != "#808080":
                start = i
            curr_color = color
    else:
        # No break: runs are only closed on a color change, so a target run
        # that reaches the last link is closed here.
        if curr_color == target_hex:
            end = len(color_list)

    # presumably a run of k links spans k + 1 leaves — confirm
    if end is None or (end-start)+1 != cluster_sizes[cluster_i]:
        return None

    # get subset of loading
    cumsizes = np.cumsum(cluster_sizes)
    if cluster_i==0:
        loading_start = 0
    else:
        loading_start = cumsizes[cluster_i-1]
    subset_loading = loading.T.iloc[:,loading_start:cumsizes[cluster_i]]

    # plotting
    N = subset_loading.shape[1]
    length = N*.05
    dendro_size = [0,.746,length,.12]
    heatmap_size = [0,.5,length,.25]
    fig = plt.figure(figsize=(size,size*2))
    dendro_ax = fig.add_axes(dendro_size)
    heatmap_ax = fig.add_axes(heatmap_size)
    cbar_size = [length+.22, .5, .05, .25]
    factor_avg_size = [length+.01,.5,.2,.25]
    factor_avg_ax = fig.add_axes(factor_avg_size)
    cbar_ax = fig.add_axes(cbar_size)
    plot_tree(tree, range(start, end), dendro_ax, linewidth=size/2)
    dendro_ax.set_xticklabels('')

    max_val = np.max(loading.values)
    # if max_val is high, just make it 1
    if max_val > .9:
        max_val = 1
    sns.heatmap(subset_loading, ax=heatmap_ax,
                cbar=True,
                cbar_ax=cbar_ax,
                cbar_kws={'ticks': [-max_val, 0, max_val]},
                yticklabels=True,
                vmin=-max_val,
                vmax=max_val,
                cmap=colormap,)
    yn, _ = subset_loading.shape
    tick_label_size = size*30/max(yn, 8)
    heatmap_ax.tick_params(labelsize=tick_label_size, length=size*.5,
                           width=size/5, pad=size)
    heatmap_ax.set_yticklabels(heatmap_ax.get_yticklabels(), rotation=0)
    heatmap_ax.set_xticks([i+.5 for i in range(0,subset_loading.shape[1])])
    heatmap_ax.set_xticklabels([str(i) for i in range(1,subset_loading.shape[1]+1)],
                                size=size*2, rotation=0, ha='center')

    # mean absolute loading of each row over the cluster's columns
    avg_factors = abs(subset_loading).mean(1)
    # format cbar axis
    cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.tick_params(labelsize=size*3)
    cbar_ax.set_ylabel('Factor Loading', rotation=-90, fontsize=size*3,
                       labelpad=size*2)
    # add axis labels as text above
    text_ax = fig.add_axes([-.22,.44-.02*N,.4,.02*N])
    for spine in ['top','right','bottom','left']:
        text_ax.spines[spine].set_visible(False)
    for i, label in enumerate(subset_loading.columns):
        text_ax.text(0, 1-i/N, str(i+1)+'.', fontsize=size*2.8, ha='right')
        text_ax.text(.1, 1-i/N, label, fontsize=size*3)
    text_ax.tick_params(which='both', labelbottom=False, labelleft=False,
                        bottom=False, left=False)
    # average factor bar
    avg_factors[::-1].plot(kind='barh', ax = factor_avg_ax, width=.7,
                     color= tree['color_list'][start])
    factor_avg_ax.set_xlim(0, max_val)
    factor_avg_ax.set_xticklabels('')
    factor_avg_ax.set_yticklabels('')
    factor_avg_ax.tick_params(length=0)
    factor_avg_ax.spines['top'].set_visible(False)
    factor_avg_ax.spines['bottom'].set_visible(False)
    factor_avg_ax.spines['left'].set_visible(False)
    factor_avg_ax.spines['right'].set_visible(False)

    # title and axes styling of dendrogram
    if title:
        dendro_ax.set_title(title, fontsize=size*3, y=1.05, fontweight='bold')
    dendro_ax.get_yaxis().set_visible(False)
    dendro_ax.spines['top'].set_visible(False)
    dendro_ax.spines['right'].set_visible(False)
    dendro_ax.spines['bottom'].set_visible(False)
    dendro_ax.spines['left'].set_visible(False)
    if plot_loc is not None:
        save_figure(fig, plot_loc, {'bbox_inches': 'tight', 'dpi': dpi})
        # close this figure explicitly rather than whatever is current
        plt.close(fig)
    else:
        return fig
Example #50
0
def calculate_color(density_dict,
                    colorscheme=None,
                    counter_data=None,
                    invert=False):
    """
    Transforms the population densities to a gmap color for mapping

    Parameters
    ----------
    density_dict : dict
        Dictionary with districts as a key and its population density as a
        value
    colorscheme : string
        It defines the colorscheme that will be used in the painting of the
        districts. It supports: 'Greys', 'viridis', 'inferno' and 'plasma'.
        Any other value (including None) falls back to 'Greys'.
    counter_data : dictionary
        Currently unused; accepted for backward compatibility.
    invert : boolean
        If true, it inverts the colors of the colorscheme.

    Returns
    -------
    gmaps_color : dict
        Dictionary with districts as a key and its gmap color as a
        '#rrggbb' string. Empty when density_dict is empty.
    """
    if not density_dict:
        return {}

    # get the biggest population density in the set
    biggest_density = max(density_dict.values())

    # normalize the density according to the maximum; an all-zero input
    # would otherwise divide by zero, so map everything to 0 in that case
    if biggest_density:
        normalized_values = {
            key: val / biggest_density
            for key, val in density_dict.items()
        }
    else:
        normalized_values = {key: 0.0 for key in density_dict}

    if invert:
        # invert values v -> (1 - v)
        normalized_values = {
            key: 1 - val
            for key, val in normalized_values.items()
        }

    # define matplotlib colorscheme; unknown names fall back to Greys
    colorschemes = {
        'Greys': Greys,
        'plasma': plasma,
        'inferno': inferno,
        'viridis': viridis,
    }
    colorscheme_func = colorschemes.get(colorscheme, Greys)

    # map each normalized density through the colormap and convert the
    # resulting color to a CSS hex string without alpha
    gmaps_color = {
        key: to_hex(colorscheme_func(val), keep_alpha=False)
        for key, val in normalized_values.items()
    }

    return gmaps_color
Example #51
0
def figure_edit(axes, parent=None):
    """Edit matplotlib figure options

    Open a form dialog (``formlayout.fedit``) for *axes* and apply the result.
    The form contains general settings (title, axis limits, labels, scales
    and font sizes), per-line style settings, and per-image colormap, limit
    and interpolation settings. ``apply_callback`` is used both as the
    dialog's apply action and on the final result. Font sizes are written to
    the global ``matplotlib.rcParams``.
    """
    sep = (None, None)  # separator

    # Get / General
    # Cast to builtin floats as they have nicer reprs.
    xmin, xmax = map(float, axes.get_xlim())
    ymin, ymax = map(float, axes.get_ylim())
    general = [
        ('Title', axes.get_title()),
        ('Title size', matplotlib.rcParams['axes.titlesize']),
        ('Label size', matplotlib.rcParams['axes.labelsize']),
        ('Legend size', matplotlib.rcParams['legend.fontsize']),
        sep,
        sep,
        (None, "<b>X-Axis</b>"),
        ('Left', xmin),
        ('Right', xmax),
        ('Label', axes.get_xlabel()),
        ('Scale', [axes.get_xscale(), 'linear', 'log', 'logit']),
        ('Tick label size', matplotlib.rcParams['xtick.labelsize']),
        sep,
        (None, "<b>Y-Axis</b>"),
        ('Bottom', ymin),
        ('Top', ymax),
        ('Label', axes.get_ylabel()),
        ('Scale', [axes.get_yscale(), 'linear', 'log', 'logit']),
        ('Tick label size', matplotlib.rcParams['ytick.labelsize']),
        sep,
        ('(Re-)Generate automatic legend', False),
    ]

    # Save the unit data
    xconverter = axes.xaxis.converter
    yconverter = axes.yaxis.converter
    xunits = axes.xaxis.get_units()
    yunits = axes.yaxis.get_units()

    # Sorting for default labels (_lineXXX, _imageXXX).
    def cmp_key(label):
        match = re.match(r"(_line|_image)(\d+)", label)
        if match:
            return match.group(1), int(match.group(2))
        else:
            return label, 0

    # Get / Curves
    linedict = {}
    lines = set()
    unnamed = 0

    # add containered lines by legend name (errorbar)
    for container in axes.containers:
        label = container.get_label()
        if label == '_nolegend_':
            unnamed += 1
            label = 'unnamed #%d' % unnamed
        for obj in container:
            if isinstance(obj, matplotlib.lines.Line2D):
                linedict[label] = obj
                lines.add(obj)
                break

    # add normal lines
    for line in axes.get_lines():
        if line in lines:
            continue
        label = line.get_label()
        if label == '_nolegend_':
            unnamed += 1
            label = 'unnamed #%d' % unnamed
        linedict[label] = line

    curves = []

    def prepare_data(d, init):
        """Prepare entry for FormLayout.

        `d` is a mapping of shorthands to style names (a single style may
        have multiple shorthands, in particular the shorthands `None`,
        `"None"`, `"none"` and `""` are synonyms); `init` is one shorthand
        of the initial style.

        This function returns an list suitable for initializing a
        FormLayout combobox, namely `[initial_name, (shorthand,
        style_name), (shorthand, style_name), ...]`.
        """
        # Drop duplicate shorthands from dict (by overwriting them during
        # the dict comprehension).
        name2short = {name: short for short, name in d.items()}
        # Convert back to {shorthand: name}.
        short2name = {short: name for name, short in name2short.items()}
        # Find the kept shorthand for the style specified by init.
        canonical_init = name2short[d[init]]
        # Sort by representation and prepend the initial value.
        return ([canonical_init] + sorted(
            short2name.items(), key=lambda short_and_name: short_and_name[1]))

    curvelabels = sorted(linedict, key=cmp_key)
    for label in curvelabels:
        line = linedict[label]
        color = mcolors.to_hex(mcolors.to_rgba(line.get_color(),
                                               line.get_alpha()),
                               keep_alpha=True)
        ec = mcolors.to_hex(mcolors.to_rgba(line.get_markeredgecolor(),
                                            line.get_alpha()),
                            keep_alpha=True)
        fc = mcolors.to_hex(mcolors.to_rgba(line.get_markerfacecolor(),
                                            line.get_alpha()),
                            keep_alpha=True)
        curvedata = [
            ('Label', label), sep, (None, '<b>Line</b>'),
            ('Line style', prepare_data(LINESTYLES, line.get_linestyle())),
            ('Draw style', prepare_data(DRAWSTYLES, line.get_drawstyle())),
            ('Width', line.get_linewidth()), ('Color (RGBA)', color), sep,
            (None, '<b>Marker</b>'),
            ('Style', prepare_data(MARKERS, line.get_marker())),
            ('Size', line.get_markersize()), ('Face color (RGBA)', fc),
            ('Edge color (RGBA)', ec)
        ]
        curves.append([curvedata, label, ""])
    # Is there a curve displayed?
    has_curve = bool(curves)

    # Get / Images
    imagedict = {}
    for image in axes.get_images():
        label = image.get_label()
        if label == '_nolegend_':
            continue
        imagedict[label] = image
    imagelabels = sorted(imagedict, key=cmp_key)
    images = []
    cmaps = [(cmap, name) for name, cmap in sorted(cm.cmap_d.items())]
    for label in imagelabels:
        image = imagedict[label]
        cmap = image.get_cmap()
        if cmap not in cm.cmap_d.values():
            cmaps = [(cmap, cmap.name)] + cmaps
        low, high = image.get_clim()
        imagedata = [('Label', label), ('Colormap', [cmap.name] + cmaps),
                     ('Min. value', low), ('Max. value', high),
                     ('Interpolation', [image.get_interpolation()] +
                      [(name, name)
                       for name in sorted(mimage.interpolations_names)])]
        images.append([imagedata, label, ""])
    # Is there an image displayed?
    has_image = bool(images)

    datalist = [(general, "Axes", "")]
    if curves:
        datalist.append((curves, "Curves", ""))
    if images:
        datalist.append((images, "Images", ""))

    def apply_callback(data):
        """This function will be called to apply changes"""
        orig_xlim = axes.get_xlim()
        orig_ylim = axes.get_ylim()

        general = data.pop(0)
        curves = data.pop(0) if has_curve else []
        images = data.pop(0) if has_image else []
        if data:
            raise ValueError("Unexpected field")

        # Set / General
        (title, titlesize, labelsize, legendsize, xmin, xmax, xlabel, xscale,
         xticksize, ymin, ymax, ylabel, yscale, yticksize,
         generate_legend) = general

        if axes.get_xscale() != xscale:
            axes.set_xscale(xscale)
        if axes.get_yscale() != yscale:
            axes.set_yscale(yscale)

        axes.set_title(title)
        axes.set_xlim(xmin, xmax)
        axes.set_xlabel(xlabel)
        axes.set_ylim(ymin, ymax)
        axes.set_ylabel(ylabel)

        orig_sizes = (matplotlib.rcParams['axes.titlesize'],
                      matplotlib.rcParams['axes.labelsize'],
                      matplotlib.rcParams['legend.fontsize'],
                      matplotlib.rcParams['xtick.labelsize'],
                      matplotlib.rcParams['ytick.labelsize'])
        new_sizes = (titlesize, labelsize, legendsize, xticksize, yticksize)

        # Restore font data (global rcParams, affects later figures too)
        matplotlib.rcParams['axes.titlesize'] = titlesize
        matplotlib.rcParams['axes.labelsize'] = labelsize
        matplotlib.rcParams['legend.fontsize'] = legendsize
        matplotlib.rcParams['xtick.labelsize'] = xticksize
        matplotlib.rcParams['ytick.labelsize'] = yticksize

        # Restore the unit data
        axes.xaxis.converter = xconverter
        axes.yaxis.converter = yconverter
        axes.xaxis.set_units(xunits)
        axes.yaxis.set_units(yunits)
        axes.xaxis._update_axisinfo()
        axes.yaxis._update_axisinfo()

        # Set / Curves
        for index, curve in enumerate(curves):
            line = linedict[curvelabels[index]]
            (label, linestyle, drawstyle, linewidth, color, marker, markersize,
             markerfacecolor, markeredgecolor) = curve
            line.set_label(label)
            line.set_linestyle(linestyle)
            line.set_drawstyle(drawstyle)
            line.set_linewidth(linewidth)
            rgba = mcolors.to_rgba(color)
            line.set_alpha(None)
            line.set_color(rgba)
            # Compare by value: `is not` on a str literal tests identity and
            # is not guaranteed to match an equal string from the form.
            if marker != 'none':
                line.set_marker(marker)
                line.set_markersize(markersize)
                line.set_markerfacecolor(markerfacecolor)
                line.set_markeredgecolor(markeredgecolor)

        # Set / Images
        for index, image_settings in enumerate(images):
            image = imagedict[imagelabels[index]]
            label, cmap, low, high, interpolation = image_settings
            image.set_label(label)
            image.set_cmap(cm.get_cmap(cmap))
            image.set_clim(*sorted([low, high]))
            image.set_interpolation(interpolation)

        # re-generate legend, if checkbox is checked
        if generate_legend:
            draggable = None
            ncol = 1
            if axes.legend_ is not None:
                old_legend = axes.get_legend()
                draggable = old_legend._draggable is not None
                ncol = old_legend._ncol
            new_legend = axes.legend(ncol=ncol)
            if new_legend:
                # Pass an explicit bool: with the legacy Legend.draggable API
                # a state of None toggles instead of disabling, which would
                # make a legend created from scratch draggable.
                new_legend.draggable(bool(draggable))

        # Redraw
        figure = axes.get_figure()
        figure.canvas.draw()
        if not (axes.get_xlim() == orig_xlim and axes.get_ylim() == orig_ylim):
            figure.canvas.toolbar.push_current()
        if orig_sizes != new_sizes:
            figure.canvas.ufit_replot()
            figure.tight_layout(pad=2)

    data = formlayout.fedit(datalist,
                            title="Figure options",
                            parent=parent,
                            icon=get_icon('qt4_editor_options.svg'),
                            apply=apply_callback)
    if data is not None:
        apply_callback(data)
Example #52
0
def color_to_num(color_string):
    """Convert a sequence of numeric values/strings into a hex colour string.

    Every element of ``color_string`` is converted with ``int`` and the
    resulting tuple is handed to ``to_hex``. The input sequence is not
    modified, so tuples and other immutable sequences work as well as lists.

    :param color_string: sequence of values accepted by ``int``
                         (e.g. ``['1', '0', '0']``)
    :return: the hex colour string produced by ``to_hex``
    """
    # NOTE(review): if ``to_hex`` is matplotlib's, channels must lie in
    # [0, 1], so after int() only 0/1 channels are accepted — confirm intent.
    return to_hex(tuple(int(channel) for channel in color_string))
Example #53
0
def test_xkcd():
    """Bare and ``xkcd:``-prefixed 'blue' must resolve to distinct hex codes."""
    expected_by_name = {
        "blue": "#0000ff",
        "xkcd:blue": "#0343df",
    }
    for color_name, hex_code in expected_by_name.items():
        assert mcolors.to_hex(color_name) == hex_code
Example #54
0
    frac[f_ion > 0.1] = b'med'  # orange
    frac[f_ion > 0.2] = b'high'  # red
    frac[(f_ion > 0.2) & (temperature < 1e5)] = b'phot'
    return frac


# set up the new temperature colormap
temp_colors = sns.blend_palette(
    ('salmon', "#984ea3", "#4daf4a", "#ffe34d", 'darkorange'), n_colors=17)
phase_color_labels = [b'cold1', b'cold2', b'cold3', b'cool', b'cool1', b'cool2',
                      b'cool3', b'warm', b'warm1', b'warm2', b'warm3', b'hot',
                      b'hot1', b'hot2', b'hot3']
temperature_discrete_cmap = mpl.colors.ListedColormap(temp_colors)
# Map each phase label to the hex string of the palette colour at the same
# position, preserving label order.
# NOTE(review): the palette has 17 colours but there are only 15 labels, so
# the last two colours never appear in this mapping (they are still part of
# temperature_discrete_cmap) — confirm this is intended.
new_phase_color_key = collections.OrderedDict(
    (label, to_hex(color))
    for label, color in zip(phase_color_labels, temp_colors)
)

def new_categorize_by_temp(temp):
    """Assign a temperature-phase byte label to each element of ``temp``.

    Masks are applied from the highest threshold down, so each stricter
    mask overwrites the previous one: an element ends up with the label of
    the smallest threshold it lies below (e.g. 5.5 -> b'warm1').

    NOTE(review): the thresholds (4.8 .. 9.) look like log10 temperatures —
    confirm against callers.
    NOTE(review): as visible here the body stops after the b'cool1' bin and
    has no ``return`` statement, so the function returns None; the snippet
    appears truncated — confirm against the full source.
    """
    # Fixed-width (5-byte) byte-string array, one slot per input element.
    # NOTE(review): presumably uninitialised like np.empty, so elements with
    # temp >= 9. keep arbitrary bytes — confirm.
    phase = np.chararray(np.size(temp), 5)
    phase[temp < 9.] = b'hot3'
    phase[temp < 6.6] = b'hot2'
    phase[temp < 6.4] = b'hot1'
    phase[temp < 6.2] = b'hot'
    phase[temp < 6.] = b'warm3'
    phase[temp < 5.8] = b'warm2'
    phase[temp < 5.6] = b'warm1'
    phase[temp < 5.4] = b'warm'
    phase[temp < 5.2] = b'cool3'
    phase[temp < 5.] = b'cool2'
    phase[temp < 4.8] = b'cool1'
Example #55
0
    def plot_cdf(self,
                 workload='jankbench',
                 metric='frame_total_duration',
                 threshold=16,
                 tag='.*',
                 kernel='.*',
                 test='.*'):
        """
        Display cumulative distribution functions of a certain metric

        Draws CDFs of metrics in the results. Check ``workloads`` and
        ``workload_available_metrics`` to find the available workloads and
        metrics. Check ``tags``, ``tests`` and ``kernels`` to find the
        names that results can be filtered against.

        The most likely use-case for this is plotting frame rendering times
        under Jankbench, so default parameters are provided to make this easy.

        One CDF line is drawn per (test, tag, kernel) combination. A dashed
        horizontal line of the same colour marks the ``below`` value returned
        by ``self._get_cdf`` for that combination. If ``_get_metric_df``
        returns None, nothing is plotted.

        :param workload: Name of workload to display metrics for
        :param metric: Name of metric to display

        :param threshold: Value to highlight in the plot - the likely use for
                          this is highlighting the maximum acceptable
                          frame-rendering time in order to see at a glance the
                          rough proportion of frames that were rendered in time.

        :param tag: regular expression to filter tags that should be plotted
        :param kernel: regular expression to filter kernels that should be plotted
        :param test: regular expression to filter tests that should be plotted
        """
        df = self._get_metric_df(workload, metric, tag, kernel, test)
        if df is None:
            return

        # Group once and reuse for both the colour count and the plotting loop.
        groups = df.groupby(['test', 'tag', 'kernel'])
        test_cnt = len(groups)
        colors = iter(cm.rainbow(np.linspace(0, 1, test_cnt + 1)))

        _, axes = plt.subplots()
        # Shade the acceptable region [0, threshold].
        axes.axvspan(0, threshold, facecolor='g', alpha=0.1)

        labels = []
        lines = []
        for keys, group_df in groups:
            test_name, tag_name, kernel_name = keys
            labels.append("{:16s}: {:32s}".format(kernel_name, tag_name))
            color = to_hex(next(colors))
            cdf = self._get_cdf(group_df['value'], threshold)
            # Raises ValueError if a group mixes several units.
            [units] = group_df['units'].unique()
            # NOTE(review): the title is rewritten on every iteration, so it
            # only reports the last group, and it says "Total duration"
            # whatever `metric` is.
            ax = cdf.df.plot(ax=axes, legend=False, xlim=(0, None), figsize=(16, 6),
                             title='Total duration CDF ({:.1f}% within {} [{}] threshold)'
                             .format(100. * cdf.below, threshold, units),
                             label=test_name,
                             color=color)
            lines.append(ax.lines[-1])
            axes.axhline(y=cdf.below,
                         linewidth=1,
                         linestyle='--',
                         color=color)
            self._log.debug("%-32s: %-32s: %.1f", kernel_name, tag_name,
                            100. * cdf.below)

        axes.grid(True)
        # Explicit handles/labels override the per-line `label` values.
        axes.legend(lines, labels)
        plt.show()
Example #56
0
def embedding(
    adata: AnnData,
    basis: str,
    *,
    color: Union[str, Sequence[str], None] = None,
    gene_symbols: Optional[str] = None,
    use_raw: Optional[bool] = None,
    sort_order: bool = True,
    edges: bool = False,
    edges_width: float = 0.1,
    edges_color: Union[str, Sequence[float], Sequence[str]] = 'grey',
    neighbors_key: Optional[str] = None,
    arrows: bool = False,
    arrows_kwds: Optional[Mapping[str, Any]] = None,
    groups: Optional[str] = None,
    components: Union[str, Sequence[str]] = None,
    layer: Optional[str] = None,
    projection: Literal['2d', '3d'] = '2d',
    # image parameters
    img_key: Optional[str] = None,
    crop_coord: Tuple[int, int, int, int] = None,
    alpha_img: float = 1.0,
    bw: bool = False,
    library_id: str = None,
    #
    color_map: Union[Colormap, str, None] = None,
    cmap: Union[Colormap, str, None] = None,
    palette: Union[str, Sequence[str], Cycler, None] = None,
    na_color: ColorLike = "lightgray",
    na_in_legend: bool = True,
    size: Union[float, Sequence[float], None] = None,
    frameon: Optional[bool] = None,
    legend_fontsize: Union[int, float, _FontSize, None] = None,
    legend_fontweight: Union[int, _FontWeight] = 'bold',
    legend_loc: str = 'right margin',
    legend_fontoutline: Optional[int] = None,
    vmax: Union[VMinMax, Sequence[VMinMax], None] = None,
    vmin: Union[VMinMax, Sequence[VMinMax], None] = None,
    add_outline: Optional[bool] = False,
    outline_width: Tuple[float, float] = (0.3, 0.05),
    outline_color: Tuple[str, str] = ('black', 'white'),
    ncols: int = 4,
    hspace: float = 0.25,
    wspace: Optional[float] = None,
    title: Union[str, Sequence[str], None] = None,
    show: Optional[bool] = None,
    save: Union[bool, str, None] = None,
    ax: Optional[Axes] = None,
    return_fig: Optional[bool] = None,
    **kwargs,
) -> Union[Figure, Axes, None]:
    """\
    Scatter plot for user specified embedding basis (e.g. umap, pca, etc)

    Parameters
    ----------
    basis
        Name of the `obsm` basis to use.
    {adata_color_etc}
    {edges_arrows}
    {scatter_bulk}
    {show_save_ax}

    Returns
    -------
    If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
    If `return_fig` is True the :class:`~matplotlib.figure.Figure` is returned
    instead (also when `ax` was supplied by the caller).
    """
    check_projection(projection)
    sanitize_anndata(adata)

    # Setting up color map for continuous values
    if color_map is not None:
        if cmap is not None:
            raise ValueError("Cannot specify both `color_map` and `cmap`.")
        else:
            cmap = color_map
    cmap = copy(get_cmap(cmap))
    cmap.set_bad(na_color)
    kwargs["cmap"] = cmap

    # Prevents warnings during legend creation
    na_color = colors.to_hex(na_color, keep_alpha=True)

    if size is not None:
        kwargs['s'] = size
    if 'edgecolor' not in kwargs:
        # by default turn off edge color. Otherwise, for
        # very small sizes the edge will not reduce its size
        # (https://github.com/theislab/scanpy/issues/293)
        kwargs['edgecolor'] = 'none'

    if groups:
        if isinstance(groups, str):
            groups = [groups]

    args_3d = dict(projection='3d') if projection == '3d' else {}

    # Deal with Raw
    if use_raw is None:
        # check if adata.raw is set
        use_raw = layer is None and adata.raw is not None
    if use_raw and layer is not None:
        raise ValueError(
            "Cannot use both a layer and the raw representation. Was passed:"
            f"use_raw={use_raw}, layer={layer}."
        )

    if wspace is None:
        #  try to set a wspace that is not too large or too small given the
        #  current figure size
        wspace = 0.75 / rcParams['figure.figsize'][0] + 0.02
    if adata.raw is None and use_raw:
        raise ValueError(
            "`use_raw` is set to True but AnnData object does not have raw. "
            "Please check."
        )
    # turn color into a python list
    color = [color] if isinstance(color, str) or color is None else list(color)
    if title is not None:
        # turn title into a python list if not None
        title = [title] if isinstance(title, str) else list(title)

    # get the points position and the components list
    # (only if components is not None)
    data_points, components_list = _get_data_points(
        adata, basis, projection, components, img_key, library_id
    )

    # An empty component list means "use the default components".
    if len(components_list) == 0:
        components_list = [None]

    # Setup layout.
    # Most of the code is for the case when multiple plots are required
    # 'color' is a list of names that want to be plotted.
    # Eg. ['Gene1', 'louvain', 'Gene2'].
    # component_list is a list of components [[0,1], [1,2]]
    if (
        not isinstance(color, str)
        and isinstance(color, cabc.Sequence)
        and len(color) > 1
    ) or len(components_list) > 1:
        if ax is not None:
            raise ValueError(
                "Cannot specify `ax` when plotting multiple panels "
                "(each for a given value of 'color')."
            )

        # each plot needs to be its own panel
        num_panels = len(color) * len(components_list)
        fig, grid = _panel_grid(hspace, wspace, ncols, num_panels)
    else:
        grid = None
        if ax is None:
            fig = pl.figure()
            ax = fig.add_subplot(111, **args_3d)
        else:
            # Bind `fig` so `return_fig=True` also works with a caller-supplied ax.
            fig = ax.get_figure()

    # turn vmax and vmin into a sequence
    if isinstance(vmax, str) or not isinstance(vmax, cabc.Sequence):
        vmax = [vmax]
    if isinstance(vmin, str) or not isinstance(vmin, cabc.Sequence):
        vmin = [vmin]

    if 's' in kwargs:
        size = kwargs.pop('s')

    if size is not None:
        # A per-observation size sequence is converted to a float ndarray so
        # it can be re-ordered together with the points below.
        if (
            isinstance(size, (cabc.Sequence, pd.Series, np.ndarray))
            and len(size) == adata.shape[0]
        ):
            size = np.array(size, dtype=float)
    else:
        size = 120000 / adata.shape[0]

    ###
    # make the plots
    axs = []
    import itertools

    idx_components = range(len(components_list))

    # use itertools.product to make a plot for each color and for each component
    # For example if color=[gene1, gene2] and components=['1,2, '2,3'].
    # The plots are: [
    #     color=gene1, components=[1,2], color=gene1, components=[2,3],
    #     color=gene2, components = [1, 2], color=gene2, components=[2,3],
    # ]
    for count, (value_to_plot, component_idx) in enumerate(
        itertools.product(color, idx_components)
    ):
        color_source_vector = _get_color_source_vector(
            adata,
            value_to_plot,
            layer=layer,
            use_raw=use_raw,
            gene_symbols=gene_symbols,
            groups=groups,
        )
        color_vector, categorical = _color_vector(
            adata,
            value_to_plot,
            color_source_vector,
            palette=palette,
            na_color=na_color,
        )

        ### Order points
        order = slice(None)
        if sort_order is True and value_to_plot is not None and categorical is False:
            # Higher values plotted on top, null values on bottom
            order = np.argsort(-color_vector, kind="stable")[::-1]
        elif sort_order and categorical:
            # Null points go on bottom
            order = np.argsort(~pd.isnull(color_source_vector), kind="stable")
        # Set orders. `order` indexes the original observation order (like
        # `data_points`), so per-point sizes are re-indexed from the untouched
        # `size` for every panel; overwriting `size` would compound the
        # permutation across panels and mismatch sizes and points.
        if isinstance(size, np.ndarray):
            _size = size[order]
        else:
            _size = size
        color_source_vector = color_source_vector[order]
        color_vector = color_vector[order]
        _data_points = data_points[component_idx][order, :]

        # if plotting multiple panels, get the ax from the grid spec
        # else use the ax value (either user given or created previously)
        if grid:
            ax = pl.subplot(grid[count], **args_3d)
            axs.append(ax)
        if not (settings._frameon if frameon is None else frameon):
            ax.axis('off')
        if title is None:
            if value_to_plot is not None:
                ax.set_title(value_to_plot)
            else:
                ax.set_title('')
        else:
            try:
                ax.set_title(title[count])
            except IndexError:
                logg.warning(
                    "The title list is shorter than the number of panels. "
                    "Using 'color' value instead for some plots."
                )
                ax.set_title(value_to_plot)

        # check vmin and vmax options
        if categorical:
            kwargs['vmin'] = kwargs['vmax'] = None
        else:
            kwargs['vmin'], kwargs['vmax'] = _get_vmin_vmax(
                vmin, vmax, count, color_vector
            )

        # make the scatter plot
        if projection == '3d':
            cax = ax.scatter(
                _data_points[:, 0],
                _data_points[:, 1],
                _data_points[:, 2],
                marker=".",
                c=color_vector,
                rasterized=settings._vector_friendly,
                **kwargs,
            )
        else:
            if img_key is not None:
                # had to return size_spot cause spot size is set according
                # to the image to be plotted
                img_processed, img_coord, size_spot, cmap_img = _process_image(
                    adata, data_points, img_key, crop_coord, _size, library_id, bw
                )
                ax.imshow(img_processed, cmap=cmap_img, alpha=alpha_img)
                ax.set_xlim(img_coord[0], img_coord[1])
                ax.set_ylim(img_coord[3], img_coord[2])
            elif img_key is None and library_id is not None:
                # order of magnitude similar to public visium
                size_spot = 70 * _size

            scatter = (
                partial(ax.scatter, s=_size, plotnonfinite=True)
                if library_id is None
                else partial(circles, s=size_spot, ax=ax)
            )

            if add_outline:
                # the default outline is a black edge followed by a
                # thin white edged added around connected clusters.
                # To add an outline
                # three overlapping scatter plots are drawn:
                # First black dots with slightly larger size,
                # then, white dots a bit smaller, but still larger
                # than the final dots. Then the final dots are drawn
                # with some transparency.

                bg_width, gap_width = outline_width
                point = np.sqrt(_size)
                gap_size = (point + (point * gap_width) * 2) ** 2
                bg_size = (np.sqrt(gap_size) + (point * bg_width) * 2) ** 2
                # the default black and white colors can be changes using
                # the contour_config parameter
                bg_color, gap_color = outline_color

                # remove edge from kwargs if present
                # because edge needs to be set to None
                kwargs['edgecolor'] = 'none'

                # remove alpha for outline
                alpha = kwargs.pop('alpha') if 'alpha' in kwargs else None

                ax.scatter(
                    _data_points[:, 0],
                    _data_points[:, 1],
                    s=bg_size,
                    marker=".",
                    c=bg_color,
                    rasterized=settings._vector_friendly,
                    **kwargs,
                )
                ax.scatter(
                    _data_points[:, 0],
                    _data_points[:, 1],
                    s=gap_size,
                    marker=".",
                    c=gap_color,
                    rasterized=settings._vector_friendly,
                    **kwargs,
                )
                # if user did not set alpha, set alpha to 0.7
                kwargs['alpha'] = 0.7 if alpha is None else alpha

            cax = scatter(
                _data_points[:, 0],
                _data_points[:, 1],
                marker=".",
                c=color_vector,
                rasterized=settings._vector_friendly,
                **kwargs,
            )

        # remove y and x ticks
        ax.set_yticks([])
        ax.set_xticks([])
        if projection == '3d':
            ax.set_zticks([])

        # set default axis_labels
        name = _basis2name(basis)
        if components is not None:
            axis_labels = [name + str(x + 1) for x in components_list[component_idx]]
        elif projection == '3d':
            axis_labels = [name + str(x + 1) for x in range(3)]
        else:
            axis_labels = [name + str(x + 1) for x in range(2)]

        ax.set_xlabel(axis_labels[0])
        ax.set_ylabel(axis_labels[1])
        if projection == '3d':
            # shift the label closer to the axis
            ax.set_zlabel(axis_labels[2], labelpad=-7)
        ax.autoscale_view()

        if edges:
            _utils.plot_edges(ax, adata, basis, edges_width, edges_color, neighbors_key)
        if arrows:
            _utils.plot_arrows(ax, adata, basis, arrows_kwds)

        if value_to_plot is None:
            # if only dots were plotted without an associated value
            # there is not need to plot a legend or a colorbar
            continue

        if legend_fontoutline is not None:
            path_effect = [
                patheffects.withStroke(linewidth=legend_fontoutline, foreground='w')
            ]
        else:
            path_effect = None

        # Adding legends
        if categorical:
            _add_categorical_legend(
                ax,
                color_source_vector,
                palette=_get_palette(adata, value_to_plot),
                scatter_array=_data_points,
                legend_loc=legend_loc,
                legend_fontweight=legend_fontweight,
                legend_fontsize=legend_fontsize,
                legend_fontoutline=path_effect,
                na_color=na_color,
                na_in_legend=na_in_legend,
                multi_panel=bool(grid),
            )
        else:
            # TODO: na_in_legend should have some effect here
            pl.colorbar(cax, ax=ax, pad=0.01, fraction=0.08, aspect=30)

    if return_fig is True:
        return fig
    axs = axs if grid else ax
    _utils.savefig_or_show(basis, show=show, save=save)
    if show is False:
        return axs
Example #57
0
def test_color_names():
    """Bare, ``xkcd:`` and ``tab:`` colour names resolve to their hex codes."""
    cases = [
        ("blue", "#0000ff"),
        ("xkcd:blue", "#0343df"),
        ("tab:blue", "#1f77b4"),
    ]
    for color_name, expected_hex in cases:
        assert mcolors.to_hex(color_name) == expected_hex
Example #58
0
def bubble(adata,
           genes,
           group,
           gene_order=None,
           group_order=None,
           layer=None,
           theme=None,
           cmap=None,
           color_key=None,
           color_key_cmap='Spectral',
           background="white",
           pointsize=None,
           vmin=0,
           vmax=100,
           sym_c=False,
           alpha=0.8,
           edgecolor=None,
           linewidth=0,
           type='violin',
           sort='diagnoal',
           transpose=False,
           rotate_xlabel='horizontal',
           rotate_ylabel='horizontal',
           figsize=None,
           save_show_or_return='show',
           save_kwargs=None,
           **kwargs):
    """Bubble plots generalized to velocity, acceleration, curvature.
    It supports either the `dot` or `violin` plot mode. This function is loosely based on
    https://github.com/QuKunLab/COVID-19/blob/master/step3_plot_umap_and_marker_gene_expression.ipynb

    Parameters
    ----------
        adata: :class:`~anndata.AnnData`
            an Annodata object
        genes: `list`
            The gene list, i.e. marker gene or top acceleration, curvature genes, etc.
            Only genes present in `adata.var_names` are plotted.
        group: `str`
            The column key in `adata.obs` that will be used to group cells.
        gene_order: `None` or `list` (default: `None`)
            The gene groups order that will show up in the resulting bubble plot.
        group_order: `None` or `list` (default: `None`)
            The cells groups order that will show up in the resulting bubble plot.
        layer: `None` or `str` (default: `None`)
            The layer of data to use for the bubble plot. When None it is chosen
            from `adata.uns['pp']` (total layer for labeling data, spliced otherwise).
        theme: string (optional, default None)
            A color theme to use for plotting. A small set of
            predefined themes are provided which have relatively
            good aesthetics. Available themes are:
               * 'blue'
               * 'red'
               * 'green'
               * 'inferno'
               * 'fire'
               * 'viridis'
               * 'darkblue'
               * 'darkred'
               * 'darkgreen'
        cmap: string (optional, default None)
            The name of a matplotlib colormap used to color the dots by mean
            expression in `dot` mode; `'viridis'` is used when None.
        color_key: dict or array, shape (n_categories) (optional, default None)
            Colors for the cell groups, passed as `palette` to the violin plot.
            When None, one color per group is sampled evenly from `color_key_cmap`.
        color_key_cmap: string (optional, default 'Spectral')
            The name of a matplotlib colormap used to build `color_key` when it
            is not given.
        background: string or None (optional, default 'white')
            The figure face color. When None, `rcParams['figure.facecolor']` is
            used to pick the default theme, and rcParams are reset afterwards
            only when a background was given.
        pointsize: `None` or `float` (default: None)
            Scale of the dot size in `dot` mode. The maximal dot size is
            `16000.0 / np.sqrt(adata.shape[0])` when None, otherwise
            `16000.0 / (len(genes) * len(clusters)) * pointsize`; each dot is
            scaled by the fraction of cells with non-zero expression.
        vmin: `float` (default: `0`)
            Percentile (0-100) of each gene's values used as the lower axis limit in `violin` mode.
        vmax: `float` (default: `100`)
            Percentile (0-100) of each gene's values used as the upper axis limit in `violin` mode.
        sym_c: `bool` (default: `False`)
            Whether do you want to make the limits of continuous color to be symmetric, normally this should be used for
            plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative values.
        alpha: `float` (default: `0.8`)
            alpha value of the plot
        edgecolor: `str` or `None` (default: `None`)
            The color of the edge of the dots when type is to be `dot`.
        linewidth: `float` (default: `0`)
            The width of the edge of the dots when type is to be `dot`.
        type: `str` (default: `violin`)
            The type of the bubble plot, one of `{'violin', 'dot'}`.
        figsize: `None` or `[float, float]` (default: None)
            The width and height of a figure.
        sort: `str` (default: `diagnoal`)
            The method for sorting genes. Not implemented. Need to implement in 2021.
        transpose: `bool` (default: `False`)
            Whether to transpose the row/column of the resulting bubble plot. Gene and cell types are on x/y-axis by
            default.
        rotate_xlabel: `str` or `float` (default: `horizontal`)
            The angle to rotate the x-label.
        rotate_ylabel: `str` or `float` (default: `horizontal`)
            The angle to rotate the y-label.
        save_show_or_return: `str` {'save', 'show', 'return'} (default: `show`)
            Whether to save, show or return the figure.
        save_kwargs: `dict` or `None` (default: `None`)
            A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the save_fig function
            will use the {"path": None, "prefix": 'violin', "dpi": None, "ext": 'pdf', "transparent": True, "close":
            True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modify those keys
            according to your needs.
        kwargs:
            Additional arguments passed to sns.violinplot.

    Returns
    -------
        `(fig, axes)` when `save_show_or_return == 'return'`, otherwise None.
    """
    from matplotlib import rcParams
    from matplotlib.colors import to_hex
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Avoid a shared mutable default argument.
    save_kwargs = {} if save_kwargs is None else save_kwargs

    if background is None:
        _background = rcParams.get("figure.facecolor")
        # The builtin `type` is shadowed by the plot-type parameter, so
        # isinstance() must be used here.
        if isinstance(_background, tuple):
            _background = to_hex(_background)
    else:
        _background = background

    if theme is None:
        if _background in ["#ffffff", "black"]:
            _theme_ = "glasbey_dark"
        else:
            _theme_ = "glasbey_white"
    else:
        _theme_ = theme
    # NOTE(review): `_cmap` is not used below; the lookup still raises
    # KeyError for an unknown theme, which is kept as-is.
    _cmap = _themes[_theme_]["cmap"] if cmap is None else cmap

    if layer is None:
        mapper = get_mapper()

        has_splicing, \
        has_labeling, \
        splicing_labeling, \
        has_protein = \
            adata.uns['pp']['has_splicing'], \
            adata.uns['pp']['has_labeling'], \
            adata.uns['pp']['splicing_labeling'], \
            adata.uns['pp']['has_protein']

        # Labeling data (with or without splicing) uses the total layer.
        if splicing_labeling or has_labeling:
            layer = mapper['X_total'] if mapper[
                'X_total'] in adata.layers else 'X_total'
        else:
            layer = mapper['X_spliced'] if mapper[
                'X_spliced'] in adata.layers else 'X_spliced'

    if group not in adata.obs_keys():
        raise ValueError(
            f"argument group {group} is not a column name in `adata.obs`")

    requested_genes = list(genes)
    genes = adata.var_names.intersection(set(requested_genes)).to_list()
    if len(genes) == 0:
        raise ValueError(
            f"names from argument genes {requested_genes} don't match any genes from `adata.var_names`."
        )

    # sort gene/cluster to update the orders
    uniq_groups = adata.obs[group].unique()
    if group_order is None:
        clusters = uniq_groups
    else:
        if not set(group_order).issubset(uniq_groups):
            raise ValueError(
                f"names from argument group_order {group_order} are not subsets of "
                f"`adata.obs[group].unique()`.")
        clusters = group_order

    if gene_order is not None:
        if not set(gene_order).issubset(genes):
            raise ValueError(
                f"names from argument gene_order {gene_order} is not a subset of "
                f"`adata.var_names.intersection(set(genes)).to_list()`.")
        genes = gene_order

    cells_df = adata.obs.get(group)
    gene_df = adata[:, genes].layers[layer]
    gene_df = gene_df.A if issparse(gene_df) else gene_df
    gene_df = pd.DataFrame(gene_df.T, index=genes, columns=adata.obs_names)

    # Per-gene percentile limits as plain arrays (positional indexing below).
    xmin = gene_df.quantile(vmin / 100, axis=1).to_numpy()
    xmax = gene_df.quantile(vmax / 100, axis=1).to_numpy()
    if sym_c:
        # Symmetric limits around zero: bound = max(|lower|, upper) per gene.
        bounds = np.nanmax(np.vstack([np.abs(xmin), xmax]), axis=0)
        xmin, xmax = -bounds, bounds

    if pointsize is None:
        point_size = 16000.0 / np.sqrt(adata.shape[0])
    else:
        point_size = 16000.0 / (len(genes) * len(clusters)) * pointsize

    if color_key is None:
        num_labels = np.unique(clusters).shape[0]
        color_key = plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels))

    if figsize is None:
        width = 6 * len(genes) / 14 if transpose else 9 * len(genes) / 14
        # NOTE(review): without transpose the clusters run along the y-axis,
        # yet the height scales with len(genes) — confirm whether
        # len(clusters) was intended.
        height = 4.5 * len(clusters) / 14 if transpose else 4.5 * len(
            genes) / 14
        figsize = (height, width) if transpose else (width, height)
    else:
        figsize = figsize[::-1] if transpose else figsize

    fig, axes = plt.subplots(len(genes) if transpose else 1,
                             1 if transpose else len(genes),
                             figsize=figsize,
                             facecolor=background,
                             squeeze=False)
    # squeeze=False + ravel keeps `axes` 1-D even for a single gene.
    axes = axes.ravel()
    fig.subplots_adjust(hspace=0, wspace=0)
    clusters_vec = cells_df.loc[gene_df.columns.values].values

    # may also use clusters when transpose
    for igene, gene in enumerate(genes):
        cur_gene_df = pd.DataFrame({
            gene: gene_df.loc[gene, :].values,
            "clusters_": clusters_vec
        })
        cur_gene_df = cur_gene_df.loc[cur_gene_df['clusters_'].isin(clusters)]

        if type == 'violin':
            # use sort here
            sns.violinplot(
                data=cur_gene_df,
                x='clusters_' if transpose else gene,
                y=gene if transpose else "clusters_",
                orient='v' if transpose else 'h',
                order=clusters,  # genes if transpose else
                linewidth=None,
                palette=color_key,
                inner='box',
                scale='width',
                cut=0,
                ax=axes[igene],
                alpha=alpha,
                **kwargs)
            if transpose:
                axes[igene].set_ylim(xmin[igene], xmax[igene])
                axes[igene].set_yticks([])
                axes[igene].set_ylabel(gene,
                                       rotation=rotate_ylabel,
                                       ha='right',
                                       va='center')
            else:
                axes[igene].set_xlim(xmin[igene], xmax[igene])
                axes[igene].set_xticks([])
                axes[igene].set_xlabel(gene,
                                       rotation=rotate_xlabel,
                                       ha='right')

        elif type == 'dot':
            # Mean expression and fraction of expressing cells per cluster;
            # the expression column is named after the gene itself.
            avg_perc_cluster = cur_gene_df.groupby(
                'clusters_')[gene].apply(lambda x: pd.Series(
                    [x.mean(), (x != 0).sum() / len(x)])).unstack()
            avg_perc_cluster.columns = ['avg', 'perc']

            # One dot per cluster, all placed on this gene's row/column, so
            # x and y must have the same length.
            gene_coord = [gene] * len(clusters)
            axes[igene].scatter(
                x=clusters if transpose else gene_coord,
                y=gene_coord if transpose else clusters,
                s=avg_perc_cluster.loc[clusters, 'perc'] * point_size,
                c=avg_perc_cluster.loc[clusters, 'avg'],
                cmap='viridis' if cmap is None else cmap,
                rasterized=False,
                edgecolor=edgecolor,
                # only `linewidth`: also passing its alias `lw` is rejected
                linewidth=linewidth,
                alpha=alpha,
            )
        if transpose:
            if igene != len(genes) - 1:
                axes[igene].set_xticks([])
            else:
                axes[igene].set_xticklabels(list(map(str, np.array(clusters))),
                                            rotation=rotate_xlabel,
                                            ha="right")
        else:
            if igene != 0:
                axes[igene].set_yticks([])
            else:
                axes[igene].set_yticklabels(list(map(str, np.array(clusters))),
                                            rotation=rotate_ylabel,
                                            ha="right",
                                            va='center')
        if transpose:
            axes[igene].set_xlabel('')
        else:
            axes[igene].set_ylabel('')

    if save_show_or_return == "save":
        s_kwargs = {
            "path": None,
            "prefix": 'violin',
            "dpi": None,
            "ext": 'pdf',
            "transparent": True,
            "close": True,
            "verbose": True
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        save_fig(**s_kwargs)
        if background is not None:
            reset_rcParams()
    elif save_show_or_return == "show":
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plt.tight_layout()

        plt.show()
        if background is not None:
            reset_rcParams()
    elif save_show_or_return == "return":
        if background is not None:
            reset_rcParams()

        return fig, axes
Example #59
0
def figure_edit(axes, parent=None):
    """Edit matplotlib figure options.

    Builds a FormLayout dialog with an "Axes" tab (title, limits, labels and
    scales of both axes, plus a "regenerate legend" checkbox). It adds a
    "Curves" tab for every labelled line and an "Images" tab for every
    labelled image on *axes*. Edited values are written back to *axes* and
    the figure canvas is redrawn.

    :param axes: the matplotlib Axes whose options are edited.
    :param parent: forwarded to ``formlayout.fedit`` as the dialog's parent
                   (presumably a Qt widget or None -- confirm with callers).
    """
    sep = (None, None)  # separator

    # Get / General
    xmin, xmax = map(float, axes.get_xlim())
    ymin, ymax = map(float, axes.get_ylim())
    general = [('Title', axes.get_title()),
               sep,
               (None, "<b>X-Axis</b>"),
               ('Min', xmin), ('Max', xmax),
               ('Label', axes.get_xlabel()),
               ('Scale', [axes.get_xscale(), 'linear', 'log']),
               sep,
               (None, "<b>Y-Axis</b>"),
               ('Min', ymin), ('Max', ymax),
               ('Label', axes.get_ylabel()),
               ('Scale', [axes.get_yscale(), 'linear', 'log']),
               sep,
               ('(Re-)Generate automatic legend', False),
               ]

    # Save the unit data; setting limits/scales below may reset it, so it is
    # restored in apply_callback.
    xconverter = axes.xaxis.converter
    yconverter = axes.yaxis.converter
    xunits = axes.xaxis.get_units()
    yunits = axes.yaxis.get_units()

    # Sorting for default labels (_lineXXX, _imageXXX): sort numerically by
    # the trailing index instead of lexicographically.
    def cmp_key(label):
        match = re.match(r"(_line|_image)(\d+)", label)
        if match:
            return match.group(1), int(match.group(2))
        else:
            return label, 0

    # Get / Curves (artists labelled '_nolegend_' are not editable)
    linedict = {}
    for line in axes.get_lines():
        label = line.get_label()
        if label == '_nolegend_':
            continue
        linedict[label] = line
    curves = []

    def prepare_data(d, init):
        """Prepare entry for FormLayout.

        `d` is a mapping of shorthands to style names (a single style may
        have multiple shorthands, in particular the shorthands `None`,
        `"None"`, `"none"` and `""` are synonyms); `init` is one shorthand
        of the initial style.

        This function returns a list suitable for initializing a
        FormLayout combobox, namely `[initial_name, (shorthand,
        style_name), (shorthand, style_name), ...]`.
        """
        # Drop duplicate shorthands from dict (by overwriting them during
        # the dict comprehension).
        name2short = {name: short for short, name in d.items()}
        # Convert back to {shorthand: name}.
        short2name = {short: name for name, short in name2short.items()}
        # Find the kept shorthand for the style specified by init.
        canonical_init = name2short[d[init]]
        # Sort by representation and prepend the initial value.
        return ([canonical_init] +
                sorted(short2name.items(),
                       key=lambda short_and_name: short_and_name[1]))

    curvelabels = sorted(linedict, key=cmp_key)
    for label in curvelabels:
        line = linedict[label]
        # Fold the line's alpha into the displayed colour so it is editable
        # as a single RGBA value.
        color = mcolors.to_hex(
            mcolors.to_rgba(line.get_color(), line.get_alpha()),
            keep_alpha=True)
        ec = mcolors.to_hex(line.get_markeredgecolor(), keep_alpha=True)
        fc = mcolors.to_hex(line.get_markerfacecolor(), keep_alpha=True)
        curvedata = [
            ('Label', label),
            sep,
            (None, '<b>Line</b>'),
            ('Line style', prepare_data(LINESTYLES, line.get_linestyle())),
            ('Draw style', prepare_data(DRAWSTYLES, line.get_drawstyle())),
            ('Width', line.get_linewidth()),
            ('Color (RGBA)', color),
            sep,
            (None, '<b>Marker</b>'),
            ('Style', prepare_data(MARKERS, line.get_marker())),
            ('Size', line.get_markersize()),
            ('Face color (RGBA)', fc),
            ('Edge color (RGBA)', ec)]
        curves.append([curvedata, label, ""])
    # Is there a curve displayed?
    has_curve = bool(curves)

    # Get / Images
    imagedict = {}
    for image in axes.get_images():
        label = image.get_label()
        if label == '_nolegend_':
            continue
        imagedict[label] = image
    imagelabels = sorted(imagedict, key=cmp_key)
    images = []
    cmaps = [(cmap, name) for name, cmap in sorted(cm.cmap_d.items())]
    for label in imagelabels:
        image = imagedict[label]
        cmap = image.get_cmap()
        # Offer unregistered colormaps too, so the current one is selectable.
        if cmap not in cm.cmap_d.values():
            cmaps = [(cmap, cmap.name)] + cmaps
        low, high = image.get_clim()
        imagedata = [
            ('Label', label),
            ('Colormap', [cmap.name] + cmaps),
            ('Min. value', low),
            ('Max. value', high),
            ('Interpolation',
             [image.get_interpolation()]
             + [(name, name) for name in sorted(image.iterpnames)])]
        images.append([imagedata, label, ""])
    # Is there an image displayed?
    has_image = bool(images)

    datalist = [(general, "Axes", "")]
    if curves:
        datalist.append((curves, "Curves", ""))
    if images:
        datalist.append((images, "Images", ""))

    def apply_callback(data):
        """This function will be called to apply changes"""
        # Tabs arrive in the same order they were appended to datalist.
        general = data.pop(0)
        curves = data.pop(0) if has_curve else []
        images = data.pop(0) if has_image else []
        if data:
            raise ValueError("Unexpected field")

        # Set / General
        (title, xmin, xmax, xlabel, xscale, ymin, ymax, ylabel, yscale,
         generate_legend) = general

        if axes.get_xscale() != xscale:
            axes.set_xscale(xscale)
        if axes.get_yscale() != yscale:
            axes.set_yscale(yscale)

        axes.set_title(title)
        axes.set_xlim(xmin, xmax)
        axes.set_xlabel(xlabel)
        axes.set_ylim(ymin, ymax)
        axes.set_ylabel(ylabel)

        # Restore the unit data
        axes.xaxis.converter = xconverter
        axes.yaxis.converter = yconverter
        axes.xaxis.set_units(xunits)
        axes.yaxis.set_units(yunits)
        axes.xaxis._update_axisinfo()
        axes.yaxis._update_axisinfo()

        # Set / Curves
        for index, curve in enumerate(curves):
            line = linedict[curvelabels[index]]
            (label, linestyle, drawstyle, linewidth, color, marker, markersize,
             markerfacecolor, markeredgecolor) = curve
            line.set_label(label)
            line.set_linestyle(linestyle)
            line.set_drawstyle(drawstyle)
            line.set_linewidth(linewidth)
            rgba = mcolors.to_rgba(color)
            # Alpha is carried inside the RGBA colour, so clear the separate
            # alpha to avoid applying it twice.
            line.set_alpha(None)
            line.set_color(rgba)
            # Apply marker settings unconditionally. Skipping them for a
            # 'none' marker would leave the old marker visible when the user
            # switches markers off; size and colours are harmless when no
            # marker is drawn.
            line.set_marker(marker)
            line.set_markersize(markersize)
            line.set_markerfacecolor(markerfacecolor)
            line.set_markeredgecolor(markeredgecolor)

        # Set / Images
        for index, image_settings in enumerate(images):
            image = imagedict[imagelabels[index]]
            label, cmap, low, high, interpolation = image_settings
            image.set_label(label)
            image.set_cmap(cm.get_cmap(cmap))
            image.set_clim(*sorted([low, high]))
            image.set_interpolation(interpolation)

        # re-generate legend, if checkbox is checked
        if generate_legend:
            # Without an existing legend, default to non-draggable. Passing
            # None to Legend.draggable would toggle the state instead.
            draggable = False
            ncol = 1
            if axes.legend_ is not None:
                old_legend = axes.get_legend()
                draggable = old_legend._draggable is not None
                ncol = old_legend._ncol
            new_legend = axes.legend(ncol=ncol)
            if new_legend:
                new_legend.draggable(draggable)

        # Redraw
        figure = axes.get_figure()
        figure.canvas.draw()

    data = formlayout.fedit(datalist, title="Figure options", parent=parent,
                            icon=get_icon('qt4_editor_options.svg'),
                            apply=apply_callback)
    # fedit returns None when the dialog is cancelled.
    if data is not None:
        apply_callback(data)
Example #60
0
df_all_EDGES_sel = pd.read_sql_query('''
                SELECT u, v, COUNT(*)
                FROM  public.mapmatching_2017
                GROUP BY u, v ''', conn_HAIG)

# Make an independent copy: a plain assignment would only alias the same
# DataFrame, so adding the 'color' column below would also modify
# df_all_EDGES_sel.
df_all_EDGES_records = df_all_EDGES_sel.copy()

### add colors based on 'records'
# The unaliased COUNT(*) column is read back under the name 'count'.
vmin = df_all_EDGES_records['count'].min()
vmax = df_all_EDGES_records['count'].max()
# Map each edge's record count onto the Reds colormap and store it as a hex
# colour string (clip=True keeps out-of-range values at the ends of the map).
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
mapper = plt.cm.ScalarMappable(norm=norm, cmap=plt.cm.Reds)  # scales of reds
df_all_EDGES_records['color'] = df_all_EDGES_records['count'].apply(
    lambda x: mcolors.to_hex(mapper.to_rgba(x)))

# filter recover_all_EDGES (geo-dataframe) with df_recover_all_EDGES_sel (dataframe)
# clean_edges_matched_route = pd.merge(df_all_EDGES_sel, gdf_all_EDGES, on=['u', 'v'],how='left')

clean_edges_matched_route = pd.read_sql_query('''
                                      WITH df_all_EDGES_sel AS(
                                      SELECT u, v, COUNT(*)
                                      FROM  mapmatching_2017
                                      GROUP BY u, v)
                                      SELECT df_all_EDGES_sel.u, 
                                             df_all_EDGES_sel.v,
                                             df_all_EDGES_sel.count, 
                                      mapmatching_2017.idtrajectory,