def get_colored_edge_weights(infr, graph=None, highlight_reviews=True):
    """
    Pick a display color for every edge of a graph.

    Args:
        infr: object providing ``graph``, ``_get_truth_colors()``,
            ``get_edge_data(edge)`` and ``get_colored_weights(weights)``.
        graph (nx.Graph | None): graph whose edges are colored. Defaults to
            ``infr.graph``.
        highlight_reviews (bool): when True, color each edge by its review
            decision (``evidence_decision`` with ``meta_decision`` as a
            fallback). When False, color by the ``'normscore'`` edge
            attribute, drawing edges without a score in gray.

    Returns:
        Tuple[list, list]: the edges and a parallel sequence of colors.
    """
    # Update color and linewidth based on scores/weight
    if graph is None:
        graph = infr.graph
    truth_colors = infr._get_truth_colors()

    if not highlight_reviews:
        edges = list(graph.edges())
        score_of = nx.get_edge_attributes(graph, 'normscore')
        weights = np.array(list(ub.dict_take(score_of, edges, np.nan)))
        missing_idxs = []
        if len(weights) > 0:
            # Unscored edges are colored as if they sat on the 0.5
            # threshold; they are repainted gray afterwards.
            missing_idxs = np.where(np.isnan(weights))[0]
            weights[missing_idxs] = .5
        colors = infr.get_colored_weights(weights)
        for idx in missing_idxs:
            colors[idx] = util.Color('gray').as01()
        return edges, colors

    def _review_color(edge):
        # NOTE(review): edge data is read through ``infr`` even when a
        # different ``graph`` is passed in -- confirm this is intended.
        data = infr.get_edge_data(edge)
        state = data.get('evidence_decision', UNREV)
        meta = data.get('meta_decision', NULL)
        chosen = truth_colors[state]
        if state not in {POSTV, NEGTV}:
            # Edges without visual evidence borrow the positive / negative
            # color when a same / diff meta decision exists.
            if meta == SAME:
                chosen = truth_colors[POSTV]
            if meta == DIFF:
                chosen = truth_colors[NEGTV]
        return chosen

    pairs = [(edge, _review_color(edge)) for edge in graph.edges()]
    edges = [edge for edge, _ in pairs]
    colors = [color for _, color in pairs]
    return edges, colors
def _assign_confusion_vectors(true_dets, pred_dets, bg_weight=1.0,
                              iou_thresh=0.5, bg_cidx=-1, bias=0.0,
                              classes=None, compat='all', prioritize='iou',
                              ignore_classes='ignore', max_dets=None):
    """
    Create confusion vectors for detections by assigning to ground truth boxes

    Given predictions and truth for an image return (y_pred, y_true,
    y_score), which is suitable for sklearn classification metrics

    Args:
        true_dets (Detections):
            groundtruth with boxes, classes, and weights

        pred_dets (Detections):
            predictions with boxes, classes, and scores

        bg_weight (float, default=1.0):
            passed through to the assignment loop. Presumably the weight
            recorded for predictions that match no truth (see the ``weight``
            column of the first example) -- TODO confirm.

        iou_thresh (float | Iterable[float], default=0.5):
            bounding box overlap iou threshold required for assignment (the
            overlap must be strictly greater than the threshold). If an
            iterable of thresholds is given, the assignment is computed once
            per threshold and a dict mapping each threshold to its result is
            returned instead.

        bias (float, default=0.0):
            for computing bounding box overlap, either 1 or 0

        compat (str, default='all'):
            can be ('ancestors' | 'mutex' | 'all').  determines which pred
            boxes are allowed to match which true boxes. If 'mutex', then pred
            boxes can only match true boxes of the same class. If 'ancestors',
            then pred boxes can match true boxes that match or have a coarser
            label. If 'all', then any pred can match any true, regardless of
            its category label. When ``classes`` is None, 'ancestors' falls
            back to 'mutex'.

        prioritize (str, default='iou'):
            can be ('iou' | 'class' | 'correct') determines which box to
            assign to if multiple true boxes overlap a predicted box.  if
            prioritize is iou, then the true box with maximum iou (above
            iou_thresh) will be chosen.  If prioritize is class, then it will
            prefer matching a compatible class above a higher iou. If
            prioritize is correct, then ancestors of the true class are
            preferred over descendants of the true class, over unrelated
            classes. Forced to 'iou' when compat is 'mutex'.

        bg_cidx (int, default=-1):
            The index of the background class.  The index used in the truth
            column when a predicted bounding box does not match any true
            bounding box.

        classes (List[str] | kwcoco.CategoryTree):
            mapping from class indices to class names. Can also contain class
            hierarchy information. If None, a mutually exclusive tree over
            every class index seen in the detections is built.

        ignore_classes (str | List[str]):
            class name(s) indicating ignore regions

        max_dets (int): maximum number of detections to consider

    TODO:
        - [ ] This is a bottleneck function. An implementation in C / C++ /
        Cython would likely improve the overall system.

        - [ ] Implement crowd truth. Allow multiple predictions to match any
              truth object marked as "iscrowd".

    Returns:
        dict: with relevant confusion vectors. The keys of this dict can be
            interpreted as columns of a data frame. The `txs` / `pxs` columns
            represent the indexes of the true / predicted annotations that were
            assigned as matching. Additionally each row also contains the true
            and predicted class index, the predicted score, the true weight and
            the iou of the true and predicted boxes. A `txs` value of -1 means
            that the predicted box was not assigned to a true annotation and a
            `pxs` value of -1 means that the true annotation was not assigned
            to any predicted annotation. If ``iou_thresh`` is iterable, a dict
            mapping each threshold to such a dict is returned.

    Raises:
        KeyError: if ``compat`` is not one of the valid options.

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> import pandas as pd
        >>> import kwimage
        >>> # Given a raw numpy representation construct Detection wrappers
        >>> true_dets = kwimage.Detections(
        >>>     boxes=kwimage.Boxes(np.array([
        >>>         [ 0,  0, 10, 10], [10,  0, 20, 10],
        >>>         [10,  0, 20, 10], [20,  0, 30, 10]]), 'tlbr'),
        >>>     weights=np.array([1, 0, .9, 1]),
        >>>     class_idxs=np.array([0, 0, 1, 2]))
        >>> pred_dets = kwimage.Detections(
        >>>     boxes=kwimage.Boxes(np.array([
        >>>         [6, 2, 20, 10], [3,  2, 9, 7],
        >>>         [3,  9, 9, 7],  [3,  2, 9, 7],
        >>>         [2,  6, 7, 7],  [20,  0, 30, 10]]), 'tlbr'),
        >>>     scores=np.array([.5, .5, .5, .5, .5, .5]),
        >>>     class_idxs=np.array([0, 0, 1, 2, 0, 1]))
        >>> bg_weight = 1.0
        >>> compat = 'all'
        >>> iou_thresh = 0.5
        >>> bias = 0.0
        >>> import kwcoco
        >>> classes = kwcoco.CategoryTree.from_mutex(list(range(3)))
        >>> bg_cidx = -1
        >>> y = _assign_confusion_vectors(true_dets, pred_dets, bias=bias,
        >>>                               bg_weight=bg_weight, iou_thresh=iou_thresh,
        >>>                               compat=compat)
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
           pred  true  score  weight     iou  txs  pxs
        0     1     2 0.5000  1.0000  1.0000    3    5
        1     0    -1 0.5000  1.0000 -1.0000   -1    4
        2     2    -1 0.5000  1.0000 -1.0000   -1    3
        3     1    -1 0.5000  1.0000 -1.0000   -1    2
        4     0    -1 0.5000  1.0000 -1.0000   -1    1
        5     0     0 0.5000  0.0000  0.6061    1    0
        6    -1     0 0.0000  1.0000 -1.0000    0   -1
        7    -1     1 0.0000  0.9000 -1.0000    2   -1

    Ignore:
        from xinspect.dynamic_kwargs import get_func_kwargs
        globals().update(get_func_kwargs(_assign_confusion_vectors))

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> import pandas as pd
        >>> from kwcoco.metrics import DetectionMetrics
        >>> dmet = DetectionMetrics.demo(nimgs=1, nclasses=8,
        >>>                              nboxes=(0, 20), n_fp=20,
        >>>                              box_noise=.2, cls_noise=.3)
        >>> classes = dmet.classes
        >>> gid = 0
        >>> true_dets = dmet.true_detections(gid)
        >>> pred_dets = dmet.pred_detections(gid)
        >>> y = _assign_confusion_vectors(true_dets, pred_dets,
        >>>                               classes=dmet.classes,
        >>>                               compat='all', prioritize='class')
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
        >>> y = _assign_confusion_vectors(true_dets, pred_dets,
        >>>                               classes=dmet.classes,
        >>>                               compat='ancestors', iou_thresh=.5)
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
    """
    import kwarray
    valid_compat_keys = {'ancestors', 'mutex', 'all'}
    if compat not in valid_compat_keys:
        raise KeyError(compat)
    if classes is None and compat == 'ancestors':
        # Without a hierarchy there are no ancestors to match against
        compat = 'mutex'

    if compat == 'mutex':
        prioritize = 'iou'

    # Group true boxes by class
    # Keep track which true boxes are unused / not assigned
    unique_tcxs, tgroupxs = kwarray.group_indices(true_dets.class_idxs)
    cx_to_txs = dict(zip(unique_tcxs, tgroupxs))

    unique_pcxs = np.array(sorted(set(pred_dets.class_idxs)))

    if classes is None:
        import kwcoco
        # Build mutually exclusive category tree
        all_cxs = sorted(
            set(map(int, unique_pcxs)) | set(map(int, unique_tcxs)))
        # An image with neither true nor predicted boxes has no class
        # indices; max() of an empty sequence would raise here.
        all_cxs = list(range(max(all_cxs) + 1)) if all_cxs else []
        classes = kwcoco.CategoryTree.from_mutex(all_cxs)

    cx_to_ancestors = classes.idx_to_ancestor_idxs()

    if prioritize == 'iou':
        pdist_priority = None  # TODO: cleanup
    else:
        pdist_priority = _fast_pdist_priority(classes, prioritize)

    if compat == 'mutex':
        # assume classes are mutually exclusive if hierarchy is not given
        cx_to_matchable_cxs = {cx: [cx] for cx in unique_pcxs}
    elif compat == 'ancestors':
        cx_to_matchable_cxs = {
            cx: sorted([cx] + sorted(
                ub.take(classes.node_to_idx,
                        nx.ancestors(classes.graph,
                                     classes.idx_to_node[cx]))))
            for cx in unique_pcxs
        }
    elif compat == 'all':
        cx_to_matchable_cxs = {cx: unique_tcxs for cx in unique_pcxs}
    else:
        raise KeyError(compat)

    if compat == 'all':
        # In this case simply run the full pairwise iou
        common_true_idxs = np.arange(len(true_dets))
        cx_to_matchable_txs = {cx: common_true_idxs for cx in unique_pcxs}
        common_ious = pred_dets.boxes.ious(true_dets.boxes, bias=bias)
        # common_ious = pred_dets.boxes.ious(true_dets.boxes, impl='c', bias=bias)
        iou_lookup = dict(enumerate(common_ious))
    else:
        # For each pred-category find matchable true-indices
        cx_to_matchable_txs = {}
        for cx, matchable_cxs in cx_to_matchable_cxs.items():
            compat_txs = ub.dict_take(cx_to_txs, matchable_cxs, default=[])
            compat_txs = np.array(sorted(ub.flatten(compat_txs)), dtype=int)
            cx_to_matchable_txs[cx] = compat_txs

        # Batch up the IOU pre-computation between compatible truths / preds
        iou_lookup = {}
        unique_pred_cxs, pgroupxs = kwarray.group_indices(pred_dets.class_idxs)
        for cx, pred_idxs in zip(unique_pred_cxs, pgroupxs):
            true_idxs = cx_to_matchable_txs[cx]
            ious = pred_dets.boxes[pred_idxs].ious(true_dets.boxes[true_idxs],
                                                   bias=bias)
            iou_lookup.update(zip(pred_idxs, ious))

    iou_thresh_list = (
        [iou_thresh] if not ub.iterable(iou_thresh) else iou_thresh)

    iou_thresh_to_y = {}
    for iou_thresh_ in iou_thresh_list:
        isvalid_lookup = {
            px: ious > iou_thresh_
            for px, ious in iou_lookup.items()
        }

        y = _critical_loop(true_dets, pred_dets, iou_lookup, isvalid_lookup,
                           cx_to_matchable_txs, bg_weight, prioritize,
                           iou_thresh_, pdist_priority, cx_to_ancestors,
                           bg_cidx, ignore_classes=ignore_classes,
                           max_dets=max_dets)
        iou_thresh_to_y[iou_thresh_] = y

    if ub.iterable(iou_thresh):
        return iou_thresh_to_y
    else:
        return y
def multi_plot(xdata=None, ydata=None, xydata=None, **kwargs):
    r"""
    plots multiple lines, bars, etc...

    One function call that concisely describes all of the most commonly
    used parameters needed when plotting a bar / line chart. This is
    especially useful when multiple plots are needed in the same domain.

    Args:
        xdata (List[ndarray] | Dict[str, ndarray] | ndarray | str):
            x-coordinate data common to all y-coordinate values or xdata for
            each line/bar in ydata.  Mutually exclusive with xydata. If ydata
            is a dict, xdata may also be one of its keys, in which case that
            entry supplies the x-coordinates and the remaining entries (in
            dict order) are plotted.

        ydata (List[ndarray] | Dict[str, ndarray] | ndarray):
            y-coordinate values for each line/bar to plot. Can also
            be just a single ndarray of scalar values. Mutually exclusive with
            xydata.

        xydata (Dict[str, Tuple[ndarray, ndarray]]):
            mapping from labels to a tuple of xdata and ydata for each line.

        **kwargs:

            fnum (int):
                figure number to draw on

            pnum (Tuple[int, int, int]):
                plot number to draw on within the figure: e.g. (1, 1, 1)

            label (List|Dict): if you specified ydata as a List[ndarray]
                this is the label for each line in that list. Note this is
                unnecessary if you specify input as a dictionary mapping labels
                to lines.

            color (str|List|Dict): either a special color code ('distinct'
                or a colormap name), a single color shared by all lines, or a
                color for each item in ydata. In the latter case, this should
                be specified as either a list or a dict depending on how ydata
                was specified.

            marker (str|List|Dict): type of matplotlib marker to use at every
                data point. Can be specified for all lines jointly or for each
                line independently.

            transpose (bool, default=False): swaps x and y data.

            kind (str, default='plot'):
                The kind of plot. Can either be 'plot' or 'bar'.
                We parse these other kwargs if:
                    if kind='plot':
                        spread
                    if kind='bar':
                        stacked, width

            Misc:
                use_legend (bool): ...
                legend_loc (str):
                    one of 'best', 'upper right', 'upper left', 'lower left',
                    'lower right', 'right', 'center left', 'center right',
                    'lower center', or 'upper center'.

            Layout:
                xlabel (str): label for x-axis
                ylabel (str): label for y-axis
                title (str): title for the axes
                figtitle (str): title for the figure

                xscale (str): can be one of [linear, log, logit, symlog]
                yscale (str): can be one of [linear, log, logit, symlog]

                xlim (Tuple[float, float]): low and high x-limit of axes
                ylim (Tuple[float, float]): low and high y-limit of axes
                xmin (float): low x-limit of axes, mutex with xlim
                xmax (float): high x-limit of axes, mutex with xlim
                ymin (float): low y-limit of axes, mutex with ylim
                ymax (float): high y-limit of axes, mutex with ylim

                titlesize (float): ...
                legendsize (float): ...
                labelsize (float): ...

            Grid:
                gridlinewidth (float): ...
                gridlinestyle (str): ...

            Ticks:
                num_xticks (int): number of x ticks
                num_yticks (int): number of y ticks
                tickwidth (float): ...
                ticklength (float): ...
                ticksize (float): ...
                xticklabels (list): list of x-tick labels, overrides num_xticks
                yticklabels (list): list of y-tick labels, overrides num_yticks
                xtick_rotation (float): xtick rotation in degrees
                ytick_rotation (float): ytick rotation in degrees

            Data:
                spread (List | Dict): Plots a spread around plot lines usually
                    indicating standard deviation

                markersize (float|List|Dict): marker size for all or each plot

                markeredgewidth (float|List|Dict): marker edge width for all or
                    each plot

                linewidth (float|List|Dict): line width for all or each plot

                linestyle (str|List|Dict): line style for all or each plot

    Notes:
        any plot_kw key can be a scalar (corresponding to all ydatas),
        a list if ydata was specified as a list, or a dict if ydata was
        specified as a dict.

        plot_kw_keys = ['label', 'color', 'marker', 'markersize',
            'markeredgewidth', 'linewidth', 'linestyle', 'alpha']

    Returns:
        matplotlib.axes.Axes: ax : the axes that was drawn on

    References:
        matplotlib.org/examples/api/barchart_demo.html

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> # The new way to use multi_plot is to pass ydata as a dict of lists
        >>> ydata = {
        >>>     'spamΣ': [1, 1, 2, 3, 5, 8, 13],
        >>>     'eggs': [3, 3, 3, 3, 3, np.nan, np.nan],
        >>>     'jamµ': [5, 3, np.nan, 1, 2, np.nan, np.nan],
        >>>     'pram': [4, 2, np.nan, 0, 0, np.nan, 1],
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, title='ΣΣΣµµµ',
        >>>                      xlabel='\nfdsΣΣΣµµµ', linestyle='--')
        >>> kwplot.show_if_requested()

    Example:
        >>> # Old way to use multi_plot is a list of lists
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xdata = [1, 2, 3, 4, 5]
        >>> ydata_list = [[1, 2, 3, 4, 5], [3, 3, 3, 3, 3], [5, 4, np.nan, 2, 1], [4, 3, np.nan, 1, 0]]
        >>> kwargs = {'label': ['spamΣ', 'eggs', 'jamµ', 'pram'],  'linestyle': '-'}
        >>> #ax = multi_plot(xdata, ydata_list, title='$\phi_1(\\vec{x})$', xlabel='\nfds', **kwargs)
        >>> ax = multi_plot(xdata, ydata_list, title='ΣΣΣµµµ', xlabel='\nfdsΣΣΣµµµ', **kwargs)
        >>> kwplot.show_if_requested()

    Example:
        >>> # Simple way to use multi_plot is to pass xdata and ydata exactly
        >>> # like you would use plt.plot
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ax = multi_plot([1, 2, 3], [4, 5, 6], fnum=4, label='foo')
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xydata = {'a': ([0, 1, 2], [0, 1, 2]), 'b': ([0, 2, 4], [2, 1, 0])}
        >>> ax = kwplot.multi_plot(xydata=xydata, fnum=4)
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {'a': [0, 1, 2], 'b': [1, 2, 1], 'c': [4, 4, 4, 3, 2]}
        >>> kwargs = {
        >>>     'spread': {'a': [.2, .3, .1], 'b': .2},
        >>>     'xlim': (-1, 5),
        >>>     'xticklabels': ['foo', 'bar'],
        >>>     'xtick_rotation': 90,
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=4, **kwargs)
        >>> kwplot.show_if_requested()

    Ignore:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {
        >>>     str(i): np.random.rand(100) + i for i in range(30)
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=1, doclf=True)
        >>> kwplot.show_if_requested()
    """
    import matplotlib as mpl
    from matplotlib import pyplot as plt

    # Initial integration with mpl rcParams standards
    mplrc = mpl.rcParams
    if 'rcParams' in kwargs:
        mplrc = mplrc.copy()
        mplrc.update(kwargs['rcParams'])

    if xydata is not None:
        if xdata is not None or ydata is not None:
            raise ValueError('Cannot specify xydata with xdata or ydata')
        if isinstance(xydata, dict):
            xdata = ub.odict((k, np.array(xy[0])) for k, xy in xydata.items())
            ydata = ub.odict((k, np.array(xy[1])) for k, xy in xydata.items())
        else:
            raise ValueError('Only supports xydata as Dict at the moment')

    if 'label' in kwargs and 'label_list' in kwargs:
        raise ValueError('Specify either label or label_list')

    if isinstance(ydata, dict):
        # Case where ydata is a dictionary
        if isinstance(xdata, six.string_types):
            # Special-er case where xdata is specified in ydata
            xkey = xdata
            # Keep dict order. A set difference here made the line order
            # (and thus colors / markers) vary between runs, and the default
            # label_list became a set that parsekw_list repeated per line.
            ykeys = [k for k in ydata.keys() if k != xkey]
            xdata = ydata[xkey]
        else:
            ykeys = list(ydata.keys())
        # Normalize input into ydata_list
        ydata_list = list(ub.take(ydata, ykeys))

        default_label_list = kwargs.pop('label', ykeys)
        kwargs['label_list'] = kwargs.get('label_list', default_label_list)
    else:
        # ydata should be a List[ndarray] or an ndarray
        ydata_list = ydata
        ykeys = None

    # allow ydata_list to be passed without a container
    if is_list_of_scalars(ydata_list):
        ydata_list = [np.array(ydata_list)]

    if xdata is None:
        xdata = list(range(max(map(len, ydata_list))))

    num_lines = len(ydata_list)

    # Transform xdata into xdata_list
    if isinstance(xdata, dict):
        xdata_list = [np.array(xdata[k], copy=True) for k in ykeys]
    elif is_list_of_lists(xdata):
        xdata_list = [np.array(xd, copy=True) for xd in xdata]
    else:
        xdata_list = [np.array(xdata, copy=True)] * num_lines

    fnum = mpl_core.ensure_fnum(kwargs.get('fnum', None))
    pnum = kwargs.get('pnum', None)
    kind = kwargs.get('kind', 'plot')
    transpose = kwargs.get('transpose', False)

    def parsekw_list(key, kwargs, num_lines=num_lines, ykeys=ykeys,
                     default=ub.NoParam):
        """
        Return properties that correspond with ydata_list.

        Searches kwargs for several keys based on the base key and finds either
        a scalar, list, or dict and coerces this into a list of properties that
        corresponds with the ydata_list. Returns None if the key is absent.
        """
        if key in kwargs:
            val_list = kwargs[key]
        elif key + '_list' in kwargs:
            # warnings.warn('*_list is depricated, just use kwarg {}'.format(key))
            val_list = kwargs[key + '_list']
        elif key + 's' in kwargs:
            # hack, multiple ways to do something
            warnings.warn('*s depricated, just use kwarg {}'.format(key))
            val_list = kwargs[key + 's']
        else:
            val_list = None

        if val_list is not None:
            if isinstance(val_list, dict):
                # Extract properly ordered dictionary values
                if ykeys is None:
                    raise ValueError(
                        'Kwarg {!r} was a dict, but ydata was not'.format(key))
                if default is ub.NoParam:
                    val_list = [val_list[ykey] for ykey in ykeys]
                else:
                    val_list = [val_list.get(ykey, default) for ykey in ykeys]

            if not isinstance(val_list, list):
                # Coerce a scalar value into a list
                val_list = [val_list] * num_lines

        return val_list

    if kind == 'plot':
        if 'marker' not in kwargs:
            kwargs['marker'] = 'distinct'

        if isinstance(kwargs['marker'], six.string_types):
            if kwargs['marker'] == 'distinct':
                kwargs['marker'] = mpl_core.distinct_markers(num_lines)
            elif kwargs['marker'] == 'cycle':
                # Note the length of marker and linestyle cycles should be
                # relatively prime.
                # https://matplotlib.org/api/markers_api.html
                marker_cycle = ['.', '*', 'x']
                kwargs['marker'] = [
                    marker_cycle[i % len(marker_cycle)]
                    for i in range(num_lines)
                ]

        if 'linestyle' not in kwargs:
            kwargs['linestyle'] = mplrc['lines.linestyle']

        if isinstance(kwargs['linestyle'], six.string_types):
            if kwargs['linestyle'] == 'cycle':
                # https://matplotlib.org/gallery/lines_bars_and_markers/line_styles_reference.html
                linestyle_cycle = ['solid', 'dashed', 'dashdot', 'dotted']
                kwargs['linestyle'] = [
                    linestyle_cycle[i % len(linestyle_cycle)]
                    for i in range(num_lines)
                ]

    if 'color' not in kwargs:
        kwargs['color'] = 'distinct'

    if isinstance(kwargs['color'], six.string_types):
        color_code = kwargs['color']
        if color_code == 'distinct':
            kwargs['color'] = mpl_core.distinct_colors(num_lines, randomize=0)
        else:
            try:
                cm = plt.get_cmap(color_code)
            except ValueError:
                # Not a colormap name: accept a single plain color
                # (e.g. 'red' or 'k') and share it across every line.
                if not mpl.colors.is_color_like(color_code):
                    raise
                kwargs['color'] = [color_code] * num_lines
            else:
                kwargs['color'] = [cm(i / num_lines) for i in range(num_lines)]

    # Parse out arguments to ax.plot
    plot_kw_keys = [
        'label', 'color', 'marker', 'markersize', 'markeredgewidth',
        'linewidth', 'linestyle', 'alpha'
    ]
    # hackish / extra args that dont directly get passed to plt.plot
    extra_plot_kw_keys = ['spread_alpha', 'autolabel', 'edgecolor', 'fill']
    plot_kw_keys += extra_plot_kw_keys
    plot_ks_vals = [parsekw_list(key, kwargs) for key in plot_kw_keys]
    plot_list_kw = {
        key: vals
        for key, vals in zip(plot_kw_keys, plot_ks_vals)
        if vals is not None
    }

    if kind == 'plot':
        if 'spread_alpha' not in plot_list_kw:
            plot_list_kw['spread_alpha'] = [.2] * num_lines

    if kind == 'bar':
        # Remove non-bar kwargs
        for key in ['markeredgewidth', 'linewidth', 'marker', 'markersize',
                    'linestyle']:
            plot_list_kw.pop(key, None)

        stacked = kwargs.get('stacked', False)
        width_key = 'height' if transpose else 'width'
        if 'width_list' in kwargs:
            # NOTE(review): with width_list and stacked=False the
            # side-by-side offset below references ``width``, which is
            # never assigned on this path.
            plot_list_kw[width_key] = kwargs['width_list']
        else:
            width = kwargs.get('width', .9)
            if not stacked:
                width /= num_lines
            plot_list_kw[width_key] = [width] * num_lines

    spread_list = parsekw_list('spread', kwargs, default=None)

    # nest into a list of dicts for each line in the multiplot
    valid_keys = list(set(plot_list_kw.keys()) - set(extra_plot_kw_keys))
    valid_vals = list(ub.dict_take(plot_list_kw, valid_keys))
    plot_kw_list = [dict(zip(valid_keys, vals)) for vals in zip(*valid_vals)]

    extra_kw_keys = [key for key in extra_plot_kw_keys if key in plot_list_kw]
    extra_kw_vals = list(ub.dict_take(plot_list_kw, extra_kw_keys))
    extra_kw_list = [
        dict(zip(extra_kw_keys, vals)) for vals in zip(*extra_kw_vals)
    ]

    # Get passed in axes or setup a new figure
    ax = kwargs.get('ax', None)
    if ax is None:
        # NOTE: This is slow, can we speed it up?
        doclf = kwargs.get('doclf', False)
        fig = mpl_core.figure(fnum=fnum, pnum=pnum, docla=False, doclf=doclf)
        ax = fig.gca()
    else:
        plt.sca(ax)
        fig = ax.figure

    # +---------------
    # Draw plot lines
    ydata_list = [np.array(yd) for yd in ydata_list]

    if transpose:
        if kind == 'bar':
            plot_func = ax.barh
        elif kind == 'plot':
            def plot_func(_x, _y, **kw):
                return ax.plot(_y, _x, **kw)
    else:
        plot_func = getattr(ax, kind)  # usually ax.plot

    if len(ydata_list) > 0:
        _iter = enumerate(
            zip_longest(xdata_list, ydata_list, plot_kw_list, extra_kw_list))
        for count, (_xdata, _ydata, plot_kw, extra_kw) in _iter:
            _ydata = _ydata[0:len(_xdata)]
            _xdata = _xdata[0:len(_ydata)]
            # Drop non-finite y values (and their x partners) before drawing
            ymask = np.isfinite(_ydata)
            ydata_ = _ydata.compress(ymask)
            xdata_ = _xdata.compress(ymask)
            if kind == 'bar':
                if not stacked:
                    # Plot bars side by side (stacked bars share positions)
                    baseoffset = (width * num_lines) / 2
                    lineoffset = (width * count)
                    offset = baseoffset - lineoffset  # Fixeme for more histogram bars
                    xdata_ = xdata_ - offset
            objs = plot_func(xdata_, ydata_, **plot_kw)

            if kind == 'bar':
                if extra_kw is not None and 'edgecolor' in extra_kw:
                    for rect in objs:
                        rect.set_edgecolor(extra_kw['edgecolor'])
                if extra_kw is not None and extra_kw.get('autolabel', False):
                    # FIXME: probably a more cannonical way to include bar
                    # autolabeling with tranpose support, but this is a hack that
                    # works for now
                    for rect in objs:
                        # Use a dedicated name: ``width`` holds the bar width
                        # that later iterations need for side-by-side offsets.
                        if transpose:
                            numlbl = bar_len = rect.get_width()
                            xpos = bar_len + (
                                (_xdata.max() - _xdata.min()) * .005)
                            ypos = rect.get_y() + rect.get_height() / 2.
                            ha, va = 'left', 'center'
                        else:
                            numlbl = bar_len = rect.get_height()
                            xpos = rect.get_x() + rect.get_width() / 2.
                            ypos = 1.05 * bar_len
                            ha, va = 'center', 'bottom'
                        barlbl = '%.3f' % (numlbl, )
                        ax.text(xpos, ypos, barlbl, ha=ha, va=va)

            if kind == 'plot' and extra_kw.get('fill', False):
                # Use the NaN-filtered x values so they align with ydata_
                ax.fill_between(xdata_, ydata_, alpha=plot_kw.get('alpha', 1.0),
                                color=plot_kw.get('color', None))  # , zorder=0)

            if spread_list is not None:
                # Plots a spread around plot lines usually indicating standard
                # deviation
                _xdata = np.array(_xdata)
                _spread = spread_list[count]
                if _spread is not None:
                    if not ub.iterable(_spread):
                        _spread = [_spread] * len(ydata_)
                    else:
                        # Align per-point spreads with the truncated,
                        # NaN-filtered values that were actually drawn.
                        _spread = np.asarray(_spread)[0:len(_ydata)]
                        if len(_spread) == len(ymask):
                            _spread = _spread[ymask]
                    ydata_ave = np.array(ydata_)
                    y_data_dev = np.array(_spread)
                    y_data_max = ydata_ave + y_data_dev
                    y_data_min = ydata_ave - y_data_dev
                    ax = plt.gca()
                    spread_alpha = extra_kw['spread_alpha']
                    ax.fill_between(xdata_, y_data_min, y_data_max,
                                    alpha=spread_alpha,
                                    color=plot_kw.get('color', None))  # , zorder=0)
            ydata = _ydata  # HACK
            xdata = _xdata  # HACK
    # L________________

    if transpose:
        ydata = xdata

        # Hack / Fix any transpose issues
        def transpose_key(key):
            if key.startswith('x'):
                return 'y' + key[1:]
            elif key.startswith('y'):
                return 'x' + key[1:]
            elif key.startswith('num_x'):
                # hackier, fixme to use regex or something
                return 'num_y' + key[5:]
            elif key.startswith('num_y'):
                # hackier, fixme to use regex or something
                return 'num_x' + key[5:]
            else:
                return key
        kwargs = {transpose_key(key): val for key, val in kwargs.items()}

    # Setup axes labeling
    title = kwargs.get('title', None)
    xlabel = kwargs.get('xlabel', '')
    ylabel = kwargs.get('ylabel', '')

    def none_or_unicode(text):
        return None if text is None else ub.ensure_unicode(text)

    xlabel = none_or_unicode(xlabel)
    ylabel = none_or_unicode(ylabel)
    title = none_or_unicode(title)

    titlesize = kwargs.get('titlesize', mplrc['axes.titlesize'])
    labelsize = kwargs.get('labelsize', mplrc['axes.labelsize'])
    legendsize = kwargs.get('legendsize', mplrc['legend.fontsize'])
    xticksize = kwargs.get('ticksize', mplrc['xtick.labelsize'])
    yticksize = kwargs.get('ticksize', mplrc['ytick.labelsize'])
    family = kwargs.get('fontfamily', mplrc['font.family'])

    tickformat = kwargs.get('tickformat', None)
    ytickformat = kwargs.get('ytickformat', tickformat)
    xtickformat = kwargs.get('xtickformat', tickformat)

    # 'DejaVu Sans','Verdana', 'Arial'
    weight = kwargs.get('fontweight', None)
    if weight is None:
        weight = 'normal'

    labelkw = {
        'fontproperties': mpl.font_manager.FontProperties(
            weight=weight, family=family, size=labelsize)
    }
    ax.set_xlabel(xlabel, **labelkw)
    ax.set_ylabel(ylabel, **labelkw)

    tick_fontprop = mpl.font_manager.FontProperties(family=family,
                                                    weight=weight)

    if tick_fontprop is not None:
        # NOTE: This is slow, can we speed it up?
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_fontproperties(tick_fontprop)
        for ticklabel in ax.get_yticklabels():
            ticklabel.set_fontproperties(tick_fontprop)
    if xticksize is not None:
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_fontsize(xticksize)
    if yticksize is not None:
        for ticklabel in ax.get_yticklabels():
            ticklabel.set_fontsize(yticksize)

    if xtickformat is not None:
        # mpl.ticker.StrMethodFormatter  # new style
        # mpl.ticker.FormatStrFormatter  # old style
        ax.xaxis.set_major_formatter(
            mpl.ticker.FormatStrFormatter(xtickformat))
    if ytickformat is not None:
        ax.yaxis.set_major_formatter(
            mpl.ticker.FormatStrFormatter(ytickformat))

    tick_kw = {
        'width': kwargs.get('tickwidth', None),
        'length': kwargs.get('ticklength', None),
    }
    tick_kw = {k: v for k, v in tick_kw.items() if v is not None}
    ax.xaxis.set_tick_params(**tick_kw)
    ax.yaxis.set_tick_params(**tick_kw)

    # Setup axes limits
    if 'xlim' in kwargs:
        xlim = kwargs['xlim']
        if xlim is not None:
            if 'xmin' not in kwargs and 'xmax' not in kwargs:
                kwargs['xmin'] = xlim[0]
                kwargs['xmax'] = xlim[1]
            else:
                raise ValueError('use xmax, xmin instead of xlim')
    if 'ylim' in kwargs:
        ylim = kwargs['ylim']
        if ylim is not None:
            if 'ymin' not in kwargs and 'ymax' not in kwargs:
                kwargs['ymin'] = ylim[0]
                kwargs['ymax'] = ylim[1]
            else:
                raise ValueError('use ymax, ymin instead of ylim')

    xmin = kwargs.get('xmin', ax.get_xlim()[0])
    xmax = kwargs.get('xmax', ax.get_xlim()[1])
    ymin = kwargs.get('ymin', ax.get_ylim()[0])
    ymax = kwargs.get('ymax', ax.get_ylim()[1])

    text_type = six.text_type
    if text_type(xmax) == 'data':
        xmax = max([xd.max() for xd in xdata_list])
    if text_type(xmin) == 'data':
        xmin = min([xd.min() for xd in xdata_list])

    # Setup axes ticks
    num_xticks = kwargs.get('num_xticks', None)
    num_yticks = kwargs.get('num_yticks', None)

    if num_xticks is not None:
        if xdata.dtype.kind == 'i':
            xticks = np.linspace(np.ceil(xmin), np.floor(xmax),
                                 num_xticks).astype(np.int32)
        else:
            xticks = np.linspace((xmin), (xmax), num_xticks)
        ax.set_xticks(xticks)
    if num_yticks is not None:
        if ydata.dtype.kind == 'i':
            yticks = np.linspace(np.ceil(ymin), np.floor(ymax),
                                 num_yticks).astype(np.int32)
        else:
            yticks = np.linspace((ymin), (ymax), num_yticks)
        ax.set_yticks(yticks)

    force_xticks = kwargs.get('force_xticks', None)
    if force_xticks is not None:
        xticks = np.array(sorted(ax.get_xticks().tolist() + force_xticks))
        ax.set_xticks(xticks)

    yticklabels = kwargs.get('yticklabels', None)
    if yticklabels is not None:
        # Hack ONLY WORKS WHEN TRANSPOSE = True
        # Overrides num_yticks
        missing_labels = max(len(ydata) - len(yticklabels), 0)
        yticklabels_ = yticklabels + [''] * missing_labels
        ax.set_yticks(ydata)
        ax.set_yticklabels(yticklabels_)

    xticklabels = kwargs.get('xticklabels', None)
    if xticklabels is not None:
        # Overrides num_xticks
        missing_labels = max(len(xdata) - len(xticklabels), 0)
        xticklabels_ = xticklabels + [''] * missing_labels
        ax.set_xticks(xdata)
        ax.set_xticklabels(xticklabels_)

    xticks = kwargs.get('xticks', None)
    if xticks is not None:
        ax.set_xticks(xticks)

    yticks = kwargs.get('yticks', None)
    if yticks is not None:
        ax.set_yticks(yticks)

    xtick_rotation = kwargs.get('xtick_rotation', None)
    if xtick_rotation is not None:
        for lbl in ax.get_xticklabels():
            lbl.set_rotation(xtick_rotation)
    ytick_rotation = kwargs.get('ytick_rotation', None)
    if ytick_rotation is not None:
        for lbl in ax.get_yticklabels():
            lbl.set_rotation(ytick_rotation)

    # Axis padding
    xpad = kwargs.get('xpad', None)
    ypad = kwargs.get('ypad', None)
    xpad_factor = kwargs.get('xpad_factor', None)
    ypad_factor = kwargs.get('ypad_factor', None)
    if xpad is None and xpad_factor is not None:
        xpad = (xmax - xmin) * xpad_factor
    if ypad is None and ypad_factor is not None:
        ypad = (ymax - ymin) * ypad_factor
    xpad = 0 if xpad is None else xpad
    ypad = 0 if ypad is None else ypad
    ypad_high = kwargs.get('ypad_high', ypad)
    ypad_low = kwargs.get('ypad_low', ypad)
    xpad_high = kwargs.get('xpad_high', xpad)
    xpad_low = kwargs.get('xpad_low', xpad)
    xmin, xmax = (xmin - xpad_low), (xmax + xpad_high)
    ymin, ymax = (ymin - ypad_low), (ymax + ypad_high)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    xscale = kwargs.get('xscale', None)
    yscale = kwargs.get('yscale', None)
    if yscale is not None:
        ax.set_yscale(yscale)
    if xscale is not None:
        ax.set_xscale(xscale)

    gridlinestyle = kwargs.get('gridlinestyle', None)
    gridlinewidth = kwargs.get('gridlinewidth', None)
    gridlines = ax.get_xgridlines() + ax.get_ygridlines()
    if gridlinestyle:
        for line in gridlines:
            line.set_linestyle(gridlinestyle)
    if gridlinewidth:
        for line in gridlines:
            line.set_linewidth(gridlinewidth)

    # Setup title
    if title is not None:
        titlekw = {
            'fontproperties': mpl.font_manager.FontProperties(
                family=family, weight=weight, size=titlesize)
        }
        ax.set_title(title, **titlekw)

    use_legend = kwargs.get('use_legend', 'label' in valid_keys)
    legend_loc = kwargs.get('legend_loc', mplrc['legend.loc'])
    legend_alpha = kwargs.get('legend_alpha', mplrc['legend.framealpha'])
    if use_legend:
        legendkw = {
            'alpha': legend_alpha,
            'fontproperties': mpl.font_manager.FontProperties(
                family=family, weight=weight, size=legendsize)
        }
        mpl_core.legend(loc=legend_loc, ax=ax, **legendkw)

    figtitle = kwargs.get('figtitle', None)
    if figtitle is not None:
        # mplrc['figure.titlesize'] TODO?
        mpl_core.set_figtitle(figtitle, fontfamily=family, fontweight=weight,
                              size=kwargs.get('figtitlesize'))

    # TODO: return better info
    return ax
def test_dict_take():
    """
    Regression test: ``ub.dict_take`` yields one value per key, in key order.

    ubelt 0.7.0 had a bug where iterable keys were exhausted too soon.
    """
    expected = list(range(10))
    identity_map = {value: value for value in expected}
    taken = ub.dict_take(identity_map, expected)
    assert list(taken) == expected
def 字典_取值(字典, key, default=None):
    """
    Look up several keys of a mapping at once ("字典_取值" = "dict take").

    Args:
        字典 (dict): the mapping to read from ("字典" means "dictionary").
        key (Iterable): the keys to look up, in order.
        default (object, default=None): forwarded to ``ub.dict_take`` as the
            fallback value for keys missing from the mapping.

    Returns:
        list: the looked-up values, in the same order as ``key``.
    """
    taken = ub.dict_take(字典, key, default)
    return list(taken)
def make_agraph(graph_, inplace=False):
    """
    Convert a networkx graph into a pygraphviz ``AGraph`` for rendering.

    Node ``width`` / ``height`` attributes (multiplied by an optional
    ``scale``) are treated as pixels and divided by 72 to get inches. ``pos``
    strings that are not already pinned with a trailing ``!`` are rescaled
    the same way. Nodes that share a ``groupid`` attribute become subgraphs,
    named ``cluster_<groupid>`` unless their entry in
    ``graph_.graph['groupattrs']`` sets ``cluster`` to False.

    Args:
        graph_ (nx.Graph): graph to convert.
        inplace (bool, default=False): if False, a copy is converted so the
            caller's graph keeps its original width/height/scale attributes.

    Returns:
        pygraphviz.AGraph: the converted graph.
    """
    import pygraphviz
    # NOTE(review): helper defined elsewhere; presumably works around
    # pygraphviz version quirks.
    patch_pygraphviz()

    if not inplace:
        graph_ = graph_.copy()

    # Convert to agraph format
    num_nodes = len(graph_)
    LARGE_GRAPH = 100
    is_large = num_nodes > LARGE_GRAPH

    if is_large:
        print('Making agraph for large graph %d nodes. '
              'May take time' % (num_nodes))

    # Reduce size to be in inches not pixels
    # FIXME: make robust to param settings
    # Hack to make the w/h of the node take the max instead of
    # dot which takes the minimum
    shaped_nodes = [n for n, d in graph_.nodes(data=True) if 'width' in d]
    node_dict = graph_.nodes
    # ub.dict_take returns a one-shot generator; materialize it because the
    # attributes are iterated three times below.
    node_attrs = list(ub.dict_take(node_dict, shaped_nodes))

    width_px = np.array([n['width'] for n in node_attrs])
    height_px = np.array([n['height'] for n in node_attrs])
    scale = np.array([n.get('scale', 1.0) for n in node_attrs])

    inputscale = 72.0
    width_in = width_px / inputscale * scale
    height_in = height_px / inputscale * scale
    width_in_dict = dict(zip(shaped_nodes, width_in))
    height_in_dict = dict(zip(shaped_nodes, height_in))

    nx.set_node_attributes(graph_, name='width', values=width_in_dict)
    nx.set_node_attributes(graph_, name='height', values=height_in_dict)
    nx_delete_node_attr(graph_, name='scale')

    # Check for any nodes with groupids
    node_to_groupid = nx.get_node_attributes(graph_, 'groupid')
    if node_to_groupid:
        groupid_to_nodes = ub.group_items(*zip(*node_to_groupid.items()))
    else:
        groupid_to_nodes = {}

    # Initialize agraph format
    nx_delete_None_edge_attr(graph_)
    agraph = nx.nx_agraph.to_agraph(graph_)

    # Add subgraphs labels
    # TODO: subgraph attrs
    group_attrs = graph_.graph.get('groupattrs', {})
    for groupid, nodes in groupid_to_nodes.items():
        subgraph_attrs = group_attrs.get(groupid, {}).copy()
        cluster_flag = True
        # FIXME: make this more natural to specify
        if 'cluster' in subgraph_attrs:
            cluster_flag = subgraph_attrs['cluster']
            del subgraph_attrs['cluster']
        # pygraphviz needs a string name; non-string groupids (e.g. ints)
        # previously crashed on the concatenation below.
        name = str(groupid)
        if cluster_flag:
            # graphviz treats subgraphs labeled with cluster differently
            name = 'cluster_' + name
        agraph.add_subgraph(nodes, name, **subgraph_attrs)

    for node in graph_.nodes():
        anode = pygraphviz.Node(agraph, node)
        # TODO: Generally fix node positions
        ptstr_ = anode.attr['pos']
        if (ptstr_ is not None and len(ptstr_) > 0 and
                not ptstr_.endswith('!')):
            ptstr = ptstr_.strip('[]').strip(' ').strip('()')
            # Split on whitespace and/or commas so that graphviz-style "x,y",
            # tuple-style "(x, y)" and numpy-style "[x y]" all parse.
            ptstr_list = [x for x in re.split(r'[\s,]+', ptstr) if x]
            pt_list = list(map(float, ptstr_list))
            pt_arr = np.array(pt_list) / inputscale
            new_ptstr_list = list(map(str, pt_arr))
            new_ptstr_ = ','.join(new_ptstr_list)
            if anode.attr['pin'] is True:
                anode.attr['pin'] = 'true'
            if anode.attr['pin'] == 'true':
                # A trailing "!" pins the node position in graphviz
                new_ptstr = new_ptstr_ + '!'
            else:
                new_ptstr = new_ptstr_
            anode.attr['pos'] = new_ptstr

    if graph_.graph.get('ignore_labels', False):
        for node in graph_.nodes():
            anode = pygraphviz.Node(agraph, node)
            if 'label' in anode.attr:
                try:
                    del anode.attr['label']
                except KeyError:
                    pass
    return agraph