示例#1
0
 def get_colored_edge_weights(infr, graph=None, highlight_reviews=True):
     """
     Pick a display color for each edge of a graph.

     Args:
         infr: object supplying ``graph``, ``_get_truth_colors()``,
             ``get_edge_data(edge)`` and ``get_colored_weights(weights)``.
         graph (nx.Graph, default=None): graph whose edges are colored.
             Falls back to ``infr.graph`` when not given.
         highlight_reviews (bool, default=True): when True, edges are colored
             by their ``'evidence_decision'``; edges whose evidence decision
             is neither POSTV nor NEGTV take the POSTV / NEGTV color if their
             ``'meta_decision'`` is SAME / DIFF. When False, edges are colored
             from their ``'normscore'`` attribute via
             ``infr.get_colored_weights`` and unscored edges are drawn gray.

     Returns:
         Tuple[list, Sequence]: the edges and a color for each edge, in the
         same order.
     """
     if graph is None:
         graph = infr.graph
     truth_colors = infr._get_truth_colors()

     if not highlight_reviews:
         # Score-based coloring
         edges = list(graph.edges())
         score_lookup = nx.get_edge_attributes(graph, 'normscore')
         weights = np.array(list(ub.dict_take(score_lookup, edges, np.nan)))
         missing = []
         if len(weights) > 0:
             # Substitute the neutral threshold for missing scores so the
             # colormap can be applied; those edges are overwritten with gray.
             missing = np.where(np.isnan(weights))[0]
             weights[missing] = .5
         colors = infr.get_colored_weights(weights)
         for idx in missing:
             colors[idx] = util.Color('gray').as01()
         return edges, colors

     # Review-based coloring
     edges, colors = [], []
     for edge in graph.edges():
         edge_data = infr.get_edge_data(edge)
         state = edge_data.get('evidence_decision', UNREV)
         meta = edge_data.get('meta_decision', NULL)
         color = truth_colors[state]
         if state not in {POSTV, NEGTV}:
             # Without positive/negative visual evidence, a same/diff meta
             # decision still selects the saturated positive/negative color.
             if meta == SAME:
                 color = truth_colors[POSTV]
             if meta == DIFF:
                 color = truth_colors[NEGTV]
         edges.append(edge)
         colors.append(color)
     return edges, colors
示例#2
0
def _assign_confusion_vectors(true_dets,
                              pred_dets,
                              bg_weight=1.0,
                              iou_thresh=0.5,
                              bg_cidx=-1,
                              bias=0.0,
                              classes=None,
                              compat='all',
                              prioritize='iou',
                              ignore_classes='ignore',
                              max_dets=None):
    """
    Create confusion vectors for detections by assigning to ground truth boxes

    Given predictions and truth for an image return (y_pred, y_true,
    y_score), which is suitable for sklearn classification metrics

    Args:
        true_dets (Detections):
            groundtruth with boxes, classes, and weights

        pred_dets (Detections):
            predictions with boxes, classes, and scores

        bg_weight (float, default=1.0):
            forwarded unchanged to ``_critical_loop``. Presumably the weight
            given to rows of predictions that match no true box (TODO confirm
            against ``_critical_loop``).

        iou_thresh (float | Iterable[float], default=0.5):
            bounding box overlap iou threshold required for assignment. A
            pred box is only assignable to a true box when their iou is
            strictly greater than this value. If an iterable of thresholds is
            given, the assignment is computed once per threshold (see
            Returns).

        bias (float, default=0.0):
            for computing bounding box overlap, either 1 or 0

        compat (str, default='all'):
            can be ('ancestors' | 'mutex' | 'all').  determines which pred
            boxes are allowed to match which true boxes. If 'mutex', then
            pred boxes can only match true boxes of the same class. If
            'ancestors', then pred boxes can match true boxes that match or
            have a coarser label. If 'all', then any pred can match any
            true, regardless of its category label. When ``classes`` is not
            given, 'ancestors' falls back to 'mutex'.

        prioritize (str, default='iou'):
            can be ('iou' | 'class' | 'correct') determines which box to
            assign to if multiple true boxes overlap a predicted box.  if
            prioritize is iou, then the true box with maximum iou (above
            iou_thresh) will be chosen.  If prioritize is class, then it will
            prefer matching a compatible class above a higher iou. If
            prioritize is correct, then ancestors of the true class are
            preferred over descendants of the true class, over unrelated
            classes. Forced to 'iou' when compat is 'mutex'.

        bg_cidx (int, default=-1):
            The index of the background class.  The index used in the truth
            column when a predicted bounding box does not match any true
            bounding box.

        classes (List[str] | kwcoco.CategoryTree, default=None):
            mapping from class indices to class names. Can also contain class
            hierarchy information. If not given, a mutually exclusive
            category tree is built covering indices ``0..max(cidx)`` over all
            class indices seen in ``true_dets`` and ``pred_dets``.

        ignore_classes (str | List[str], default='ignore'):
            class name(s) indicating ignore regions

        max_dets (int, default=None): maximum number of detections to consider

    Raises:
        KeyError: if ``compat`` is not one of 'ancestors', 'mutex', or 'all'.

    TODO:
        - [ ] This is a bottleneck function. An implementation in C / C++ /
        Cython would likely improve the overall system.

        - [ ] Implement crowd truth. Allow multiple predictions to match any
              truth object marked as "iscrowd".

    Returns:
        dict: with relevant confusion vectors. The keys of this dict can be
            interpreted as columns of a data frame. The `txs` / `pxs` columns
            represent the indexes of the true / predicted annotations that were
            assigned as matching. Additionally each row also contains the true
            and predicted class index, the predicted score, the true weight and
            the iou of the true and predicted boxes. A `txs` value of -1 means
            that the predicted box was not assigned to a true annotation and a
            `pxs` value of -1 means that the true annotation was not assigned
            to any predicted annotation.

            If ``iou_thresh`` is iterable, a dict mapping each threshold to
            such a dict is returned instead.

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> import pandas as pd
        >>> import kwimage
        >>> # Given a raw numpy representation construct Detection wrappers
        >>> true_dets = kwimage.Detections(
        >>>     boxes=kwimage.Boxes(np.array([
        >>>         [ 0,  0, 10, 10], [10,  0, 20, 10],
        >>>         [10,  0, 20, 10], [20,  0, 30, 10]]), 'tlbr'),
        >>>     weights=np.array([1, 0, .9, 1]),
        >>>     class_idxs=np.array([0, 0, 1, 2]))
        >>> pred_dets = kwimage.Detections(
        >>>     boxes=kwimage.Boxes(np.array([
        >>>         [6, 2, 20, 10], [3,  2, 9, 7],
        >>>         [3,  9, 9, 7],  [3,  2, 9, 7],
        >>>         [2,  6, 7, 7],  [20,  0, 30, 10]]), 'tlbr'),
        >>>     scores=np.array([.5, .5, .5, .5, .5, .5]),
        >>>     class_idxs=np.array([0, 0, 1, 2, 0, 1]))
        >>> bg_weight = 1.0
        >>> compat = 'all'
        >>> iou_thresh = 0.5
        >>> bias = 0.0
        >>> import kwcoco
        >>> classes = kwcoco.CategoryTree.from_mutex(list(range(3)))
        >>> bg_cidx = -1
        >>> y = _assign_confusion_vectors(true_dets, pred_dets, bias=bias,
        >>>                               bg_weight=bg_weight, iou_thresh=iou_thresh,
        >>>                               compat=compat)
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
           pred  true  score  weight     iou  txs  pxs
        0     1     2 0.5000  1.0000  1.0000    3    5
        1     0    -1 0.5000  1.0000 -1.0000   -1    4
        2     2    -1 0.5000  1.0000 -1.0000   -1    3
        3     1    -1 0.5000  1.0000 -1.0000   -1    2
        4     0    -1 0.5000  1.0000 -1.0000   -1    1
        5     0     0 0.5000  0.0000  0.6061    1    0
        6    -1     0 0.0000  1.0000 -1.0000    0   -1
        7    -1     1 0.0000  0.9000 -1.0000    2   -1

    Ignore:
        from xinspect.dynamic_kwargs import get_func_kwargs
        globals().update(get_func_kwargs(_assign_confusion_vectors))

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> import pandas as pd
        >>> from kwcoco.metrics import DetectionMetrics
        >>> dmet = DetectionMetrics.demo(nimgs=1, nclasses=8,
        >>>                              nboxes=(0, 20), n_fp=20,
        >>>                              box_noise=.2, cls_noise=.3)
        >>> classes = dmet.classes
        >>> gid = 0
        >>> true_dets = dmet.true_detections(gid)
        >>> pred_dets = dmet.pred_detections(gid)
        >>> y = _assign_confusion_vectors(true_dets, pred_dets,
        >>>                               classes=dmet.classes,
        >>>                               compat='all', prioritize='class')
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
        >>> y = _assign_confusion_vectors(true_dets, pred_dets,
        >>>                               classes=dmet.classes,
        >>>                               compat='ancestors', iou_thresh=.5)
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
    """
    import kwarray
    valid_compat_keys = {'ancestors', 'mutex', 'all'}
    if compat not in valid_compat_keys:
        raise KeyError(compat)
    if classes is None and compat == 'ancestors':
        # Without a hierarchy there are no ancestors to match against
        compat = 'mutex'

    if compat == 'mutex':
        # Only same-class boxes are matchable, so class priority is moot
        prioritize = 'iou'

    # Group true boxes by class
    # Keep track which true boxes are unused / not assigned
    unique_tcxs, tgroupxs = kwarray.group_indices(true_dets.class_idxs)
    cx_to_txs = dict(zip(unique_tcxs, tgroupxs))

    unique_pcxs = np.array(sorted(set(pred_dets.class_idxs)))

    if classes is None:
        import kwcoco
        # Build mutually exclusive category tree covering every class index
        # seen in either the truth or the predictions.
        observed_cxs = set(map(int, unique_pcxs)) | set(map(int, unique_tcxs))
        # default=-1 avoids ``max()`` raising on an empty sequence when there
        # are no detections at all; the tree is then simply empty.
        num_cxs = max(observed_cxs, default=-1) + 1
        classes = kwcoco.CategoryTree.from_mutex(list(range(num_cxs)))

    cx_to_ancestors = classes.idx_to_ancestor_idxs()

    if prioritize == 'iou':
        pdist_priority = None  # TODO: cleanup
    else:
        pdist_priority = _fast_pdist_priority(classes, prioritize)

    if compat == 'mutex':
        # assume classes are mutually exclusive if hierarchy is not given
        cx_to_matchable_cxs = {cx: [cx] for cx in unique_pcxs}
    elif compat == 'ancestors':
        cx_to_matchable_cxs = {
            cx: sorted([cx] + sorted(
                ub.take(classes.node_to_idx,
                        nx.ancestors(classes.graph, classes.idx_to_node[cx]))))
            for cx in unique_pcxs
        }
    elif compat == 'all':
        cx_to_matchable_cxs = {cx: unique_tcxs for cx in unique_pcxs}
    else:
        raise KeyError(compat)

    if compat == 'all':
        # In this case simply run the full pairwise iou
        common_true_idxs = np.arange(len(true_dets))
        cx_to_matchable_txs = {cx: common_true_idxs for cx in unique_pcxs}
        common_ious = pred_dets.boxes.ious(true_dets.boxes, bias=bias)
        # common_ious = pred_dets.boxes.ious(true_dets.boxes, impl='c', bias=bias)
        iou_lookup = dict(enumerate(common_ious))
    else:
        # For each pred-category find matchable true-indices
        cx_to_matchable_txs = {}
        for cx, matchable_cxs in cx_to_matchable_cxs.items():
            compat_txs = ub.dict_take(cx_to_txs, matchable_cxs, default=[])
            compat_txs = np.array(sorted(ub.flatten(compat_txs)), dtype=int)
            cx_to_matchable_txs[cx] = compat_txs

        # Batch up the IOU pre-computation between compatible truths / preds
        iou_lookup = {}
        unique_pred_cxs, pgroupxs = kwarray.group_indices(pred_dets.class_idxs)
        for cx, pred_idxs in zip(unique_pred_cxs, pgroupxs):
            true_idxs = cx_to_matchable_txs[cx]
            ious = pred_dets.boxes[pred_idxs].ious(true_dets.boxes[true_idxs],
                                                   bias=bias)
            _px_to_iou = dict(zip(pred_idxs, ious))
            iou_lookup.update(_px_to_iou)

    multiple_thresholds = ub.iterable(iou_thresh)
    iou_thresh_list = iou_thresh if multiple_thresholds else [iou_thresh]

    iou_thresh_to_y = {}
    for iou_thresh_ in iou_thresh_list:
        isvalid_lookup = {
            px: ious > iou_thresh_
            for px, ious in iou_lookup.items()
        }

        y = _critical_loop(true_dets,
                           pred_dets,
                           iou_lookup,
                           isvalid_lookup,
                           cx_to_matchable_txs,
                           bg_weight,
                           prioritize,
                           iou_thresh_,
                           pdist_priority,
                           cx_to_ancestors,
                           bg_cidx,
                           ignore_classes=ignore_classes,
                           max_dets=max_dets)
        iou_thresh_to_y[iou_thresh_] = y

    if multiple_thresholds:
        return iou_thresh_to_y
    else:
        return y
示例#3
0
def multi_plot(xdata=None, ydata=None, xydata=None, **kwargs):
    r"""
    plots multiple lines, bars, etc...

    One function call that concisely describes the all of the most commonly
    used parameters needed when plotting a bar / line char. This is especially
    useful when multiple plots are needed in the same domain.

    Args:
        xdata (List[ndarray] | Dict[str, ndarray] | ndarray):
            x-coordinate data common to all y-coordinate values or xdata for
            each line/bar in ydata.  Mutually exclusive with xydata.

        ydata (List[ndarray] | Dict[str, ndarray] | ndarary):
            y-coordinate values for each line/bar to plot. Can also
            be just a single ndarray of scalar values. Mutually exclusive with
            xydata.

        xydata (Dict[str, Tuple[ndarray, ndarray]]):
            mapping from labels to a tuple of xdata and ydata for a each line.

        **kwargs:

            fnum (int):
                figure number to draw on

            pnum (Tuple[int, int, int]):
                plot number to draw on within the figure: e.g. (1, 1, 1)

            label (List|Dict): if you specified ydata as a List[ndarray]
                this is the label for each line in that list. Note this is
                unnecessary if you specify input as a dictionary mapping labels
                to lines.

            color (str|List|Dict): either a special color code, a single color,
                or a color for each item in ydata. In the later case, this
                should be specified as either a list or a dict depending on how
                ydata was specified.

            marker (str|List|Dict): type of matplotlib marker to use at every
                data point. Can be specified for all lines jointly or for each
                line independently.

            transpose (bool, default=False): swaps x and y data.

            kind (str, default='plot'):
                The kind of plot. Can either be 'plot' or 'bar'.
                We parse these other kwargs if:
                    if kind='plot':
                        spread
                    if kind='bar':
                        stacked, width

            Misc:
                use_legend (bool): ...
                legend_loc (str):
                    one of 'best', 'upper right', 'upper left', 'lower left',
                    'lower right', 'right', 'center left', 'center right',
                    'lower center', or 'upper center'.

            Layout:
                xlabel (str): label for x-axis
                ylabel (str): label for y-axis
                title (str): title for the axes
                figtitle (str): title for the figure

                xscale (str): can be one of [linear, log, logit, symlog]
                yscale (str): can be one of [linear, log, logit, symlog]

                xlim (Tuple[float, float]): low and high x-limit of axes
                ylim (Tuple[float, float]): low and high y-limit of axes
                xmin (float): low x-limit of axes, mutex with xlim
                xmax (float): high x-limit of axes, mutex with xlim
                ymin (float): low y-limit of axes, mutex with ylim
                ymax (float): high y-limit of axes, mutex with ylim

                titlesize (float): ...
                legendsize (float): ...
                labelsize (float): ...

            Grid:
                gridlinewidth (float): ...
                gridlinestyle (str): ...
            Ticks:
                num_xticks (int): number of x ticks
                num_yticks (int): number of y ticks
                tickwidth (float): ...
                ticklength (float): ...
                ticksize (float): ...
                xticklabels (list): list of x-tick labels, overrides num_xticks
                yticklabels (list): list of y-tick labels, overrides num_yticks
                xtick_rotation (float): xtick rotation in degrees
                ytick_rotation (float): ytick rotation in degrees
            Data:
                spread (List | Dict): Plots a spread around plot lines usually
                    indicating standard deviation

                markersize (float|List|Dict): marker size for all or each plot

                markeredgewidth (float|List|Dict): marker edge width for all or each plot

                linewidth (float|List|Dict): line width for all or each plot

                linestyle (str|List|Dict): line style for all or each plot

    Notes:
        any plot_kw key can be a scalar (corresponding to all ydatas),
        a list if ydata was specified as a list, or a dict if ydata was
        specified as a dict.

        plot_kw_keys = ['label', 'color', 'marker', 'markersize',
            'markeredgewidth', 'linewidth', 'linestyle']

    Returns:
        matplotlib.axes.Axes: ax : the axes that was drawn on

    References:
        matplotlib.org/examples/api/barchart_demo.html

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> # The new way to use multi_plot is to pass ydata as a dict of lists
        >>> ydata = {
        >>>     'spamΣ': [1, 1, 2, 3, 5, 8, 13],
        >>>     'eggs': [3, 3, 3, 3, 3, np.nan, np.nan],
        >>>     'jamµ': [5, 3, np.nan, 1, 2, np.nan, np.nan],
        >>>     'pram': [4, 2, np.nan, 0, 0, np.nan, 1],
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, title='ΣΣΣµµµ',
        >>>                      xlabel='\nfdsΣΣΣµµµ', linestyle='--')
        >>> kwplot.show_if_requested()

    Example:
        >>> # Old way to use multi_plot is a list of lists
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xdata = [1, 2, 3, 4, 5]
        >>> ydata_list = [[1, 2, 3, 4, 5], [3, 3, 3, 3, 3], [5, 4, np.nan, 2, 1], [4, 3, np.nan, 1, 0]]
        >>> kwargs = {'label': ['spamΣ', 'eggs', 'jamµ', 'pram'],  'linestyle': '-'}
        >>> #ax = multi_plot(xdata, ydata_list, title='$\phi_1(\\vec{x})$', xlabel='\nfds', **kwargs)
        >>> ax = multi_plot(xdata, ydata_list, title='ΣΣΣµµµ', xlabel='\nfdsΣΣΣµµµ', **kwargs)
        >>> kwplot.show_if_requested()

    Example:
        >>> # Simple way to use multi_plot is to pass xdata and ydata exactly
        >>> # like you would use plt.plot
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ax = multi_plot([1, 2, 3], [4, 5, 6], fnum=4, label='foo')
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xydata = {'a': ([0, 1, 2], [0, 1, 2]), 'b': ([0, 2, 4], [2, 1, 0])}
        >>> ax = kwplot.multi_plot(xydata=xydata, fnum=4)
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {'a': [0, 1, 2], 'b': [1, 2, 1], 'c': [4, 4, 4, 3, 2]}
        >>> kwargs = {
        >>>     'spread': {'a': [.2, .3, .1], 'b': .2},
        >>>     'xlim': (-1, 5),
        >>>     'xticklabels': ['foo', 'bar'],
        >>>     'xtick_rotation': 90,
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=4, **kwargs)
        >>> kwplot.show_if_requested()

    Ignore:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {
        >>>     str(i): np.random.rand(100) + i for i in range(30)
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=1, doclf=True)
        >>> kwplot.show_if_requested()
    """
    import matplotlib as mpl
    from matplotlib import pyplot as plt

    # Initial integration with mpl rcParams standards
    mplrc = mpl.rcParams
    # mplrc.update({
    #     # 'legend.fontsize': custom_figure.LEGEND_SIZE,
    #     # 'legend.framealpha':
    #     # 'axes.titlesize': custom_figure.TITLE_SIZE,
    #     # 'axes.labelsize': custom_figure.LABEL_SIZE,
    #     # 'legend.facecolor': 'w',
    #     # 'font.family': 'sans-serif',
    #     # 'xtick.labelsize': custom_figure.TICK_SIZE,
    #     # 'ytick.labelsize': custom_figure.TICK_SIZE,
    # })
    if 'rcParams' in kwargs:
        mplrc = mplrc.copy()
        mplrc.update(kwargs['rcParams'])

    if xydata is not None:
        if xdata is not None or ydata is not None:
            raise ValueError('Cannot specify xydata with xdata or ydata')
        if isinstance(xydata, dict):
            xdata = ub.odict((k, np.array(xy[0])) for k, xy in xydata.items())
            ydata = ub.odict((k, np.array(xy[1])) for k, xy in xydata.items())
        else:
            raise ValueError('Only supports xydata as Dict at the moment')

    if bool('label' in kwargs) and bool('label_list' in kwargs):
        raise ValueError('Specify either label or label_list')

    if isinstance(ydata, dict):
        # Case where ydata is a dictionary
        if isinstance(xdata, six.string_types):
            # Special-er case where xdata is specified in ydata
            xkey = xdata
            ykeys = set(ydata.keys()) - {xkey}
            xdata = ydata[xkey]
        else:
            ykeys = list(ydata.keys())
        # Normalize input into ydata_list
        ydata_list = list(ub.take(ydata, ykeys))

        default_label_list = kwargs.pop('label', ykeys)
        kwargs['label_list'] = kwargs.get('label_list', default_label_list)
    else:
        # ydata should be a List[ndarray] or an ndarray
        ydata_list = ydata
        ykeys = None

    # allow ydata_list to be passed without a container
    if is_list_of_scalars(ydata_list):
        ydata_list = [np.array(ydata_list)]

    if xdata is None:
        xdata = list(range(max(map(len, ydata_list))))

    num_lines = len(ydata_list)

    # Transform xdata into xdata_list
    if isinstance(xdata, dict):
        xdata_list = [np.array(xdata[k], copy=True) for k in ykeys]
    elif is_list_of_lists(xdata):
        xdata_list = [np.array(xd, copy=True) for xd in xdata]
    else:
        xdata_list = [np.array(xdata, copy=True)] * num_lines

    fnum = mpl_core.ensure_fnum(kwargs.get('fnum', None))
    pnum = kwargs.get('pnum', None)
    kind = kwargs.get('kind', 'plot')
    transpose = kwargs.get('transpose', False)

    def parsekw_list(key,
                     kwargs,
                     num_lines=num_lines,
                     ykeys=ykeys,
                     default=ub.NoParam):
        """
        Return properties that corresponds with ydata_list.

        Searches kwargs for several keys based on the base key and finds either
        a scalar, list, or dict and coerces this into a list of properties that
        corresonds with the ydata_list.
        """
        if key in kwargs:
            val_list = kwargs[key]
        elif key + '_list' in kwargs:
            # warnings.warn('*_list is depricated, just use kwarg {}'.format(key))
            val_list = kwargs[key + '_list']
        elif key + 's' in kwargs:
            # hack, multiple ways to do something
            warnings.warn('*s depricated, just use kwarg {}'.format(key))
            val_list = kwargs[key + 's']
        else:
            val_list = None

        if val_list is not None:
            if isinstance(val_list, dict):
                # Extract propertly ordered dictionary values
                if ykeys is None:
                    raise ValueError(
                        'Kwarg {!r} was a dict, but ydata was not'.format(key))
                else:
                    if default is ub.NoParam:
                        val_list = [val_list[key] for key in ykeys]
                    else:
                        val_list = [
                            val_list.get(key, default) for key in ykeys
                        ]

            if not isinstance(val_list, list):
                # Coerce a scalar value into a list
                val_list = [val_list] * num_lines
        return val_list

    if kind == 'plot':
        if 'marker' not in kwargs:
            # kwargs['marker'] = mplrc['lines.marker']
            kwargs['marker'] = 'distinct'
            # kwargs['marker'] = 'cycle'

        if isinstance(kwargs['marker'], six.string_types):
            if kwargs['marker'] == 'distinct':
                kwargs['marker'] = mpl_core.distinct_markers(num_lines)
            elif kwargs['marker'] == 'cycle':
                # Note the length of marker and linestyle cycles should be
                # relatively prime.
                # https://matplotlib.org/api/markers_api.html
                marker_cycle = ['.', '*', 'x']
                kwargs['marker'] = [
                    marker_cycle[i % len(marker_cycle)]
                    for i in range(num_lines)
                ]
            # else:
            #     raise KeyError(kwargs['marker'])

        if 'linestyle' not in kwargs:
            # kwargs['linestyle'] = 'distinct'
            kwargs['linestyle'] = mplrc['lines.linestyle']
            # kwargs['linestyle'] = 'cycle'

        if isinstance(kwargs['linestyle'], six.string_types):
            if kwargs['linestyle'] == 'cycle':
                # https://matplotlib.org/gallery/lines_bars_and_markers/line_styles_reference.html
                linestyle_cycle = ['solid', 'dashed', 'dashdot', 'dotted']
                kwargs['linestyle'] = [
                    linestyle_cycle[i % len(linestyle_cycle)]
                    for i in range(num_lines)
                ]

    if 'color' not in kwargs:
        # kwargs['color'] = 'jet'
        # kwargs['color'] = 'gist_rainbow'
        kwargs['color'] = 'distinct'

    if isinstance(kwargs['color'], six.string_types):
        if kwargs['color'] == 'distinct':
            kwargs['color'] = mpl_core.distinct_colors(num_lines, randomize=0)
        else:
            cm = plt.get_cmap(kwargs['color'])
            kwargs['color'] = [cm(i / num_lines) for i in range(num_lines)]

    # Parse out arguments to ax.plot
    plot_kw_keys = [
        'label', 'color', 'marker', 'markersize', 'markeredgewidth',
        'linewidth', 'linestyle', 'alpha'
    ]
    # hackish / extra args that dont directly get passed to plt.plot
    extra_plot_kw_keys = ['spread_alpha', 'autolabel', 'edgecolor', 'fill']
    plot_kw_keys += extra_plot_kw_keys
    plot_ks_vals = [parsekw_list(key, kwargs) for key in plot_kw_keys]
    plot_list_kw = dict([(key, vals)
                         for key, vals in zip(plot_kw_keys, plot_ks_vals)
                         if vals is not None])

    if kind == 'plot':
        if 'spread_alpha' not in plot_list_kw:
            plot_list_kw['spread_alpha'] = [.2] * num_lines

    if kind == 'bar':
        # Remove non-bar kwargs
        for key in [
                'markeredgewidth', 'linewidth', 'marker', 'markersize',
                'linestyle'
        ]:
            plot_list_kw.pop(key, None)

        stacked = kwargs.get('stacked', False)
        width_key = 'height' if transpose else 'width'
        if 'width_list' in kwargs:
            plot_list_kw[width_key] = kwargs['width_list']
        else:
            width = kwargs.get('width', .9)
            # if width is None:
            #     # HACK: need variable width
            #     # width = np.mean(np.diff(xdata_list[0]))
            #     width = .9
            if not stacked:
                width /= num_lines
            #plot_list_kw['orientation'] = ['horizontal'] * num_lines
            plot_list_kw[width_key] = [width] * num_lines

    spread_list = parsekw_list('spread', kwargs, default=None)

    # nest into a list of dicts for each line in the multiplot
    valid_keys = list(set(plot_list_kw.keys()) - set(extra_plot_kw_keys))
    valid_vals = list(ub.dict_take(plot_list_kw, valid_keys))
    plot_kw_list = [dict(zip(valid_keys, vals)) for vals in zip(*valid_vals)]

    extra_kw_keys = [key for key in extra_plot_kw_keys if key in plot_list_kw]
    extra_kw_vals = list(ub.dict_take(plot_list_kw, extra_kw_keys))
    extra_kw_list = [
        dict(zip(extra_kw_keys, vals)) for vals in zip(*extra_kw_vals)
    ]

    # Get passed in axes or setup a new figure
    ax = kwargs.get('ax', None)
    if ax is None:
        # NOTE: This is slow, can we speed it up?
        doclf = kwargs.get('doclf', False)
        fig = mpl_core.figure(fnum=fnum, pnum=pnum, docla=False, doclf=doclf)
        ax = fig.gca()
    else:
        plt.sca(ax)
        fig = ax.figure

    # +---------------
    # Draw plot lines
    ydata_list = [np.array(ydata) for ydata in ydata_list]

    if transpose:
        if kind == 'bar':
            plot_func = ax.barh
        elif kind == 'plot':

            def plot_func(_x, _y, **kw):
                return ax.plot(_y, _x, **kw)
    else:
        plot_func = getattr(ax, kind)  # usually ax.plot

    if len(ydata_list) > 0:
        # raise ValueError('no ydata')
        _iter = enumerate(
            zip_longest(xdata_list, ydata_list, plot_kw_list, extra_kw_list))
        for count, (_xdata, _ydata, plot_kw, extra_kw) in _iter:
            _ydata = _ydata[0:len(_xdata)]
            _xdata = _xdata[0:len(_ydata)]
            ymask = np.isfinite(_ydata)
            ydata_ = _ydata.compress(ymask)
            xdata_ = _xdata.compress(ymask)
            if kind == 'bar':
                if stacked:
                    # Plot bars on top of each other
                    xdata_ = xdata_
                else:
                    # Plot bars side by side
                    baseoffset = (width * num_lines) / 2
                    lineoffset = (width * count)
                    offset = baseoffset - lineoffset  # Fixeme for more histogram bars
                    xdata_ = xdata_ - offset
                # width_key = 'height' if transpose else 'width'
                # plot_kw[width_key] = np.diff(xdata)
            objs = plot_func(xdata_, ydata_, **plot_kw)

            if kind == 'bar':
                if extra_kw is not None and 'edgecolor' in extra_kw:
                    for rect in objs:
                        rect.set_edgecolor(extra_kw['edgecolor'])
                if extra_kw is not None and extra_kw.get('autolabel', False):
                    # FIXME: probably a more cannonical way to include bar
                    # autolabeling with tranpose support, but this is a hack that
                    # works for now
                    for rect in objs:
                        if transpose:
                            numlbl = width = rect.get_width()
                            xpos = width + (
                                (_xdata.max() - _xdata.min()) * .005)
                            ypos = rect.get_y() + rect.get_height() / 2.
                            ha, va = 'left', 'center'
                        else:
                            numlbl = height = rect.get_height()
                            xpos = rect.get_x() + rect.get_width() / 2.
                            ypos = 1.05 * height
                            ha, va = 'center', 'bottom'
                        barlbl = '%.3f' % (numlbl, )
                        ax.text(xpos, ypos, barlbl, ha=ha, va=va)

            if kind == 'plot' and extra_kw.get('fill', False):
                ax.fill_between(_xdata,
                                ydata_,
                                alpha=plot_kw.get('alpha', 1.0),
                                color=plot_kw.get('color',
                                                  None))  # , zorder=0)

            if spread_list is not None:
                # Plots a spread around plot lines usually indicating standard
                # deviation
                _xdata = np.array(_xdata)
                _spread = spread_list[count]
                if _spread is not None:
                    if not ub.iterable(_spread):
                        _spread = [_spread] * len(ydata_)
                    ydata_ave = np.array(ydata_)
                    y_data_dev = np.array(_spread)
                    y_data_max = ydata_ave + y_data_dev
                    y_data_min = ydata_ave - y_data_dev
                    ax = plt.gca()
                    spread_alpha = extra_kw['spread_alpha']
                    ax.fill_between(_xdata,
                                    y_data_min,
                                    y_data_max,
                                    alpha=spread_alpha,
                                    color=plot_kw.get('color',
                                                      None))  # , zorder=0)
        ydata = _ydata  # HACK
        xdata = _xdata  # HACK
    # L________________

    #max_y = max(np.max(y_data), max_y)
    #min_y = np.min(y_data) if min_y is None else min(np.min(y_data), min_y)

    if transpose:
        #xdata_list = ydata_list
        ydata = xdata

        # Hack / Fix any transpose issues
        def transpose_key(key):
            if key.startswith('x'):
                return 'y' + key[1:]
            elif key.startswith('y'):
                return 'x' + key[1:]
            elif key.startswith('num_x'):
                # hackier, fixme to use regex or something
                return 'num_y' + key[5:]
            elif key.startswith('num_y'):
                # hackier, fixme to use regex or something
                return 'num_x' + key[5:]
            else:
                return key

        kwargs = {transpose_key(key): val for key, val in kwargs.items()}

    # Setup axes labeling
    title = kwargs.get('title', None)
    xlabel = kwargs.get('xlabel', '')
    ylabel = kwargs.get('ylabel', '')

    def none_or_unicode(text):
        return None if text is None else ub.ensure_unicode(text)

    xlabel = none_or_unicode(xlabel)
    ylabel = none_or_unicode(ylabel)
    title = none_or_unicode(title)

    titlesize = kwargs.get('titlesize', mplrc['axes.titlesize'])
    labelsize = kwargs.get('labelsize', mplrc['axes.labelsize'])
    legendsize = kwargs.get('legendsize', mplrc['legend.fontsize'])
    xticksize = kwargs.get('ticksize', mplrc['xtick.labelsize'])
    yticksize = kwargs.get('ticksize', mplrc['ytick.labelsize'])
    family = kwargs.get('fontfamily', mplrc['font.family'])

    tickformat = kwargs.get('tickformat', None)
    ytickformat = kwargs.get('ytickformat', tickformat)
    xtickformat = kwargs.get('xtickformat', tickformat)

    # 'DejaVu Sans','Verdana', 'Arial'
    weight = kwargs.get('fontweight', None)
    if weight is None:
        weight = 'normal'

    labelkw = {
        'fontproperties':
        mpl.font_manager.FontProperties(weight=weight,
                                        family=family,
                                        size=labelsize)
    }
    ax.set_xlabel(xlabel, **labelkw)
    ax.set_ylabel(ylabel, **labelkw)

    tick_fontprop = mpl.font_manager.FontProperties(family=family,
                                                    weight=weight)

    if tick_fontprop is not None:
        # NOTE: This is slow, can we speed it up?
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_fontproperties(tick_fontprop)
        for ticklabel in ax.get_yticklabels():
            ticklabel.set_fontproperties(tick_fontprop)
    if xticksize is not None:
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_fontsize(xticksize)
    if yticksize is not None:
        for ticklabel in ax.get_yticklabels():
            ticklabel.set_fontsize(yticksize)

    if xtickformat is not None:
        # mpl.ticker.StrMethodFormatter  # new style
        # mpl.ticker.FormatStrFormatter  # old style
        ax.xaxis.set_major_formatter(
            mpl.ticker.FormatStrFormatter(xtickformat))
    if ytickformat is not None:
        ax.yaxis.set_major_formatter(
            mpl.ticker.FormatStrFormatter(ytickformat))

    xtick_kw = ytick_kw = {
        'width': kwargs.get('tickwidth', None),
        'length': kwargs.get('ticklength', None),
    }
    xtick_kw = {k: v for k, v in xtick_kw.items() if v is not None}
    ytick_kw = {k: v for k, v in ytick_kw.items() if v is not None}
    ax.xaxis.set_tick_params(**xtick_kw)
    ax.yaxis.set_tick_params(**ytick_kw)

    #ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%d'))

    # Setup axes limits
    if 'xlim' in kwargs:
        xlim = kwargs['xlim']
        if xlim is not None:
            if 'xmin' not in kwargs and 'xmax' not in kwargs:
                kwargs['xmin'] = xlim[0]
                kwargs['xmax'] = xlim[1]
            else:
                raise ValueError('use xmax, xmin instead of xlim')
    if 'ylim' in kwargs:
        ylim = kwargs['ylim']
        if ylim is not None:
            if 'ymin' not in kwargs and 'ymax' not in kwargs:
                kwargs['ymin'] = ylim[0]
                kwargs['ymax'] = ylim[1]
            else:
                raise ValueError('use ymax, ymin instead of ylim')

    xmin = kwargs.get('xmin', ax.get_xlim()[0])
    xmax = kwargs.get('xmax', ax.get_xlim()[1])
    ymin = kwargs.get('ymin', ax.get_ylim()[0])
    ymax = kwargs.get('ymax', ax.get_ylim()[1])

    text_type = six.text_type

    if text_type(xmax) == 'data':
        xmax = max([xd.max() for xd in xdata_list])
    if text_type(xmin) == 'data':
        xmin = min([xd.min() for xd in xdata_list])

    # Setup axes ticks
    num_xticks = kwargs.get('num_xticks', None)
    num_yticks = kwargs.get('num_yticks', None)

    if num_xticks is not None:
        if xdata.dtype.kind == 'i':
            xticks = np.linspace(np.ceil(xmin), np.floor(xmax),
                                 num_xticks).astype(np.int32)
        else:
            xticks = np.linspace((xmin), (xmax), num_xticks)
        ax.set_xticks(xticks)
    if num_yticks is not None:
        if ydata.dtype.kind == 'i':
            yticks = np.linspace(np.ceil(ymin), np.floor(ymax),
                                 num_yticks).astype(np.int32)
        else:
            yticks = np.linspace((ymin), (ymax), num_yticks)
        ax.set_yticks(yticks)

    force_xticks = kwargs.get('force_xticks', None)
    if force_xticks is not None:
        xticks = np.array(sorted(ax.get_xticks().tolist() + force_xticks))
        ax.set_xticks(xticks)

    yticklabels = kwargs.get('yticklabels', None)
    if yticklabels is not None:
        # Hack ONLY WORKS WHEN TRANSPOSE = True
        # Overrides num_yticks
        missing_labels = max(len(ydata) - len(yticklabels), 0)
        yticklabels_ = yticklabels + [''] * missing_labels
        ax.set_yticks(ydata)
        ax.set_yticklabels(yticklabels_)

    xticklabels = kwargs.get('xticklabels', None)
    if xticklabels is not None:
        # Overrides num_xticks
        missing_labels = max(len(xdata) - len(xticklabels), 0)
        xticklabels_ = xticklabels + [''] * missing_labels
        ax.set_xticks(xdata)
        ax.set_xticklabels(xticklabels_)

    xticks = kwargs.get('xticks', None)
    if xticks is not None:
        ax.set_xticks(xticks)

    yticks = kwargs.get('yticks', None)
    if yticks is not None:
        ax.set_yticks(yticks)

    xtick_rotation = kwargs.get('xtick_rotation', None)
    if xtick_rotation is not None:
        [lbl.set_rotation(xtick_rotation) for lbl in ax.get_xticklabels()]
    ytick_rotation = kwargs.get('ytick_rotation', None)
    if ytick_rotation is not None:
        [lbl.set_rotation(ytick_rotation) for lbl in ax.get_yticklabels()]

    # Axis padding
    xpad = kwargs.get('xpad', None)
    ypad = kwargs.get('ypad', None)
    xpad_factor = kwargs.get('xpad_factor', None)
    ypad_factor = kwargs.get('ypad_factor', None)
    if xpad is None and xpad_factor is not None:
        xpad = (xmax - xmin) * xpad_factor
    if ypad is None and ypad_factor is not None:
        ypad = (ymax - ymin) * ypad_factor
    xpad = 0 if xpad is None else xpad
    ypad = 0 if ypad is None else ypad
    ypad_high = kwargs.get('ypad_high', ypad)
    ypad_low = kwargs.get('ypad_low', ypad)
    xpad_high = kwargs.get('xpad_high', xpad)
    xpad_low = kwargs.get('xpad_low', xpad)
    xmin, xmax = (xmin - xpad_low), (xmax + xpad_high)
    ymin, ymax = (ymin - ypad_low), (ymax + ypad_high)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    xscale = kwargs.get('xscale', None)
    yscale = kwargs.get('yscale', None)
    if yscale is not None:
        ax.set_yscale(yscale)
    if xscale is not None:
        ax.set_xscale(xscale)

    gridlinestyle = kwargs.get('gridlinestyle', None)
    gridlinewidth = kwargs.get('gridlinewidth', None)
    gridlines = ax.get_xgridlines() + ax.get_ygridlines()
    if gridlinestyle:
        for line in gridlines:
            line.set_linestyle(gridlinestyle)
    if gridlinewidth:
        for line in gridlines:
            line.set_linewidth(gridlinewidth)

    # Setup title
    if title is not None:
        titlekw = {
            'fontproperties':
            mpl.font_manager.FontProperties(family=family,
                                            weight=weight,
                                            size=titlesize)
        }
        ax.set_title(title, **titlekw)

    use_legend = kwargs.get('use_legend', 'label' in valid_keys)
    legend_loc = kwargs.get('legend_loc', mplrc['legend.loc'])
    legend_alpha = kwargs.get('legend_alpha', mplrc['legend.framealpha'])
    if use_legend:
        legendkw = {
            'alpha':
            legend_alpha,
            'fontproperties':
            mpl.font_manager.FontProperties(family=family,
                                            weight=weight,
                                            size=legendsize)
        }
        mpl_core.legend(loc=legend_loc, ax=ax, **legendkw)

    figtitle = kwargs.get('figtitle', None)

    if figtitle is not None:
        # mplrc['figure.titlesize'] TODO?
        mpl_core.set_figtitle(figtitle,
                              fontfamily=family,
                              fontweight=weight,
                              size=kwargs.get('figtitlesize'))

    # TODO: return better info
    return ax
示例#4
0
def test_dict_take():
    """Regression test: ``ub.dict_take`` must yield every requested value.

    ubelt 0.7.0 had a bug where iterable keys were exhausted too soon. This
    takes all ten keys of an identity mapping and checks that all ten
    values come back, in order.
    """
    expected = list(range(10))
    identity = dict(zip(expected, expected))
    taken = ub.dict_take(identity, expected)
    assert list(taken) == expected
示例#5
0
def 字典_取值(字典, key, default=None):
    """Look up several keys in a dictionary and return the values as a list.

    The function name reads "dictionary_take_values"; ``字典`` means
    "dictionary".

    Args:
        字典 (dict): mapping to read from.
        key (Iterable): the keys to look up, in order (despite the singular
            parameter name, this is iterated as a collection of keys).
        default: forwarded to ``ub.dict_take`` as its ``default``, i.e. the
            fallback value for keys missing from ``字典``.

    Returns:
        list: the values produced by ``ub.dict_take``, materialized.
    """
    taken_values = ub.dict_take(字典, key, default)
    return list(taken_values)
示例#6
0
def make_agraph(graph_, inplace=False):
    """
    Convert a networkx graph into a pygraphviz ``AGraph``.

    In addition to ``nx.nx_agraph.to_agraph`` this:

    * rescales the ``width`` / ``height`` of every node that has a ``width``
      attribute, dividing by 72 and multiplying by the node's optional
      ``scale`` attribute (default 1.0). ``scale`` is then removed via
      ``nx_delete_node_attr``;
    * adds one subgraph per distinct ``groupid`` node attribute. The
      subgraph is named ``'cluster_<groupid>'`` unless
      ``graph_.graph['groupattrs'][groupid]['cluster']`` is falsy. Any
      other entries of that dict are passed as subgraph attributes;
    * divides every node ``pos`` that does not end in ``'!'`` by the same
      factor of 72, and appends ``'!'`` when the node's ``pin`` attribute
      is ``'true'``;
    * deletes node labels when ``graph_.graph['ignore_labels']`` is truthy.

    Args:
        graph_ (nx.Graph): graph to convert.
        inplace (bool): if False (default) the modifications above are made
            on a copy. If True, ``graph_`` itself is modified (node size
            attributes, and whatever ``nx_delete_None_edge_attr`` removes --
            presumably ``None``-valued edge attributes).

    Returns:
        pygraphviz.AGraph: the converted graph.
    """
    import pygraphviz
    patch_pygraphviz()

    if not inplace:
        graph_ = graph_.copy()

    # Divisor used for both node sizes and positions (pixels/points -> inches)
    inputscale = 72.0

    num_nodes = len(graph_)
    LARGE_GRAPH = 100
    if num_nodes > LARGE_GRAPH:
        print('Making agraph for large graph %d nodes. '
              'May take time' % (num_nodes))

    # Reduce size to be in inches not pixels
    # FIXME: make robust to param settings
    # Hack to make the w/h of the node take the max instead of
    # dot which takes the minimum
    shaped_nodes = [n for n, d in graph_.nodes(data=True) if 'width' in d]
    node_dict = graph_.nodes
    # Build a real list: the attribute dicts are iterated three times below.
    # A one-shot generator (e.g. ``ub.dict_take``) would be exhausted after
    # the first pass, leaving ``height`` and ``scale`` empty.
    node_attrs = [node_dict[n] for n in shaped_nodes]

    width_px = np.array([attrs['width'] for attrs in node_attrs])
    height_px = np.array([attrs['height'] for attrs in node_attrs])
    scale = np.array([attrs.get('scale', 1.0) for attrs in node_attrs])

    width_in = width_px / inputscale * scale
    height_in = height_px / inputscale * scale
    width_in_dict = dict(zip(shaped_nodes, width_in))
    height_in_dict = dict(zip(shaped_nodes, height_in))

    nx.set_node_attributes(graph_, name='width', values=width_in_dict)
    nx.set_node_attributes(graph_, name='height', values=height_in_dict)
    nx_delete_node_attr(graph_, name='scale')

    # Collect nodes that share a ``groupid`` so they can become subgraphs
    node_to_groupid = nx.get_node_attributes(graph_, 'groupid')
    if node_to_groupid:
        groupid_to_nodes = ub.group_items(*zip(*node_to_groupid.items()))
    else:
        groupid_to_nodes = {}

    # Initialize agraph format
    nx_delete_None_edge_attr(graph_)
    agraph = nx.nx_agraph.to_agraph(graph_)

    # Add subgraphs
    # TODO: subgraph attrs
    group_attrs = graph_.graph.get('groupattrs', {})
    for groupid, nodes in groupid_to_nodes.items():
        # Copy so that popping 'cluster' does not mutate graph_.graph
        subgraph_attrs = group_attrs.get(groupid, {}).copy()
        # FIXME: make this more natural to specify
        cluster_flag = subgraph_attrs.pop('cluster', True)
        if cluster_flag:
            # graphviz treats subgraphs whose name starts with "cluster"
            # as clusters; str() lets non-string group ids work too
            name = 'cluster_' + str(groupid)
        else:
            name = groupid
        agraph.add_subgraph(nodes, name, **subgraph_attrs)

    for node in graph_.nodes():
        anode = pygraphviz.Node(agraph, node)
        # TODO: Generally fix node positions
        ptstr_ = anode.attr['pos']
        # Skip missing/empty positions and ones already ending in '!'
        if not ptstr_ or ptstr_.endswith('!'):
            continue
        # Accept "x,y", "x, y", "(x, y)" and numpy-style "[x y]" strings by
        # splitting on any run of commas and/or whitespace.
        ptstr = ptstr_.strip().strip('[]()').strip()
        ptstr_list = [tok for tok in re.split(r'[\s,]+', ptstr) if tok]
        pt_arr = np.array([float(tok) for tok in ptstr_list]) / inputscale
        new_ptstr_ = ','.join(map(str, pt_arr))
        if anode.attr['pin'] is True:
            anode.attr['pin'] = 'true'
        if anode.attr['pin'] == 'true':
            # A trailing '!' marks the position as pinned for graphviz
            new_ptstr = new_ptstr_ + '!'
        else:
            new_ptstr = new_ptstr_
        anode.attr['pos'] = new_ptstr

    if graph_.graph.get('ignore_labels', False):
        for node in graph_.nodes():
            anode = pygraphviz.Node(agraph, node)
            if 'label' in anode.attr:
                try:
                    del anode.attr['label']
                except KeyError:
                    pass
    return agraph