Example #1
    def scale(self, factor):
        r"""
        Works with tlbr, cxywh, xywh, xy, or wh formats.

        Example:
            >>> # xdoctest: +IGNORE_WHITESPACE
            >>> Boxes(np.array([1, 1, 10, 10])).scale(2).data
            array([ 2.,  2., 20., 20.])
            >>> Boxes(np.array([[1, 1, 10, 10]])).scale((2, .5)).data
            array([[ 2. ,  0.5, 20. ,  5. ]])
            >>> Boxes(np.array([[10, 10]])).scale(.5).data
            array([[5., 5.]])
        """
        boxes = self.data
        sx, sy = factor if ub.iterable(factor) else (factor, factor)
        if boxes.dtype.kind != 'f':
            new_data = boxes.astype(float)
        else:
            new_data = boxes.copy()
        new_data[..., 0:4:2] *= sx
        new_data[..., 1:4:2] *= sy
        return Boxes(new_data, self.format)
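The strided slices 0:4:2 and 1:4:2 select the x and y columns respectively, which is why the same method handles 4-column boxes (tlbr, cxywh, xywh) and 2-column data (xy, wh). A minimal standalone sketch of that indexing trick in plain NumPy (illustrative only, not the kwimage API):

    import numpy as np

    def scale_xy_columns(arr, factor):
        # Accept a scalar or an (sx, sy) pair, mirroring the method above
        sx, sy = factor if hasattr(factor, '__len__') else (factor, factor)
        out = arr.astype(float)     # astype copies, so the input is untouched
        out[..., 0:4:2] *= sx       # even columns hold x-like values
        out[..., 1:4:2] *= sy       # odd columns hold y-like values
        return out

    scale_xy_columns(np.array([[1, 1, 10, 10]]), (2, .5))
    # array([[ 2. ,  0.5, 20. ,  5. ]])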
Example #2
    def random(Points, num=1, classes=None, rng=None):
        """
        Makes random points; typically for testing purposes

        Example:
            >>> import kwimage
            >>> self = kwimage.Points.random(classes=[1, 2, 3])
            >>> self.data
            >>> print('self.data = {!r}'.format(self.data))
        """
        rng = kwarray.ensure_rng(rng)
        if ub.iterable(num):
            shape = tuple(num) + (2,)
        else:
            shape = (num, 2)
        self = Points(xy=rng.rand(*shape))
        self.data['visible'] = np.full(len(self), fill_value=2)
        if classes is not None:
            class_idxs = (rng.rand(len(self)) * len(classes)).astype(int)
            self.data['class_idxs'] = class_idxs
            self.meta['classes'] = classes
        return self
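Note how num controls the leading shape of the generated points: an integer yields (num, 2), while a tuple yields an arbitrary leading shape. A hedged standalone illustration of just that shape logic (plain NumPy, no kwimage dependency):

    import numpy as np

    rng = np.random.RandomState(0)
    for num in [5, (2, 3)]:
        # iterable num -> arbitrary leading shape; int num -> flat (num, 2)
        shape = tuple(num) + (2,) if hasattr(num, '__len__') else (num, 2)
        xy = rng.rand(*shape)
        print(num, xy.shape)
    # 5 (5, 2)
    # (2, 3) (2, 3, 2)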
Example #3
 def shift(self, amount):
     """
     Example:
         >>> # xdoctest: +IGNORE_WHITESPACE
         >>> Boxes([25, 30, 15, 10], 'xywh').shift(10)
         <Boxes(xywh, array([35., 40., 15., 10.]))>
         >>> Boxes([25, 30, 15, 10], 'xywh').shift((10, 0))
         <Boxes(xywh, array([35., 30., 15., 10.]))>
         >>> Boxes([25, 30, 15, 10], 'tlbr').shift((10, 5))
         <Boxes(tlbr, array([35., 35., 25., 15.]))>
     """
     boxes = self.data
     tx, ty = amount if ub.iterable(amount) else (amount, amount)
     new_data = boxes.astype(float).copy()
     if self.format in ['xywh', 'cxywh']:
         new_data[..., 0] += tx
         new_data[..., 1] += ty
     elif self.format in ['tlbr']:
         new_data[..., 0:4:2] += tx
         new_data[..., 1:4:2] += ty
     else:
         raise KeyError(self.format)
     return Boxes(new_data, self.format)
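The two branches exist because translation moves only the position columns of xywh / cxywh boxes (width and height are translation-invariant), while in tlbr both corners move, hence the strided slices. A standalone sketch of the tlbr branch on a raw array (assumed ndarray input, not the Boxes wrapper):

    import numpy as np

    def shift_tlbr(arr, tx, ty):
        out = arr.astype(float)
        out[..., 0:4:2] += tx  # tlx and brx both move
        out[..., 1:4:2] += ty  # tly and bry both move
        return out

    shift_tlbr(np.array([25., 30., 40., 40.]), 10, 5)
    # array([35., 35., 50., 45.])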
Example #4
    def coerce(CategoryPatterns, data=None, **kwargs):
        """
        Construct category patterns from defaults or from specific
        categories. Can accept an existing category pattern object, a
        list of known catnames, or mscoco category dictionaries.

        Example:
            >>> data = ['superstar']
            >>> self = CategoryPatterns.coerce(data)
        """
        if isinstance(data, CategoryPatterns):
            return data
        else:
            if data is None:
                # use defaults
                catnames = CategoryPatterns._default_catnames
                cname_to_cat = {
                    c['name']: c
                    for c in CategoryPatterns._default_categories
                }
                arg = list(ub.take(cname_to_cat, catnames))
            elif ub.iterable(data) and len(data) > 0:
                # choose specific categories
                if isinstance(data[0], six.string_types):
                    catnames = data
                    cname_to_cat = {
                        c['name']: c
                        for c in CategoryPatterns._default_categories
                    }
                    arg = list(ub.take(cname_to_cat, catnames))
                elif isinstance(data[0], dict):
                    arg = data
                else:
                    raise Exception('unknown item type in data: {!r}'.format(type(data[0])))
            else:
                raise Exception('cannot coerce data={!r}'.format(data))
            return CategoryPatterns(categories=arg, **kwargs)
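This coerce pattern (accept an instance, None, a list of names, or a list of dicts, and normalize everything into one constructor call) is worth seeing in isolation. A minimal hedged sketch with hypothetical default categories (coerce_categories and its defaults are illustrative, not part of the library):

    def coerce_categories(data=None,
                          defaults=({'name': 'star'}, {'name': 'superstar'})):
        # Hypothetical standalone analogue of CategoryPatterns.coerce
        by_name = {c['name']: c for c in defaults}
        if data is None:
            return list(by_name.values())            # use all defaults
        if len(data) and isinstance(data[0], str):
            return [by_name[n] for n in data]        # subset by name
        if len(data) and isinstance(data[0], dict):
            return list(data)                        # already category dicts
        raise TypeError('cannot coerce data={!r}'.format(data))

    coerce_categories(['superstar'])
    # [{'name': 'superstar'}]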
Example #5
    def lookup(self, key, default=ub.NoParam, keepid=False):
        """
        Lookup a list of object attributes

        Args:
            key (str | Iterable): name of the property you want to look up.
                Can also be a list of names, in which case we return a dict.

            default : if specified, uses this value if it doesn't exist
                in an ObjT.

            keepid: if True, return a mapping from ids to the property

        Returns:
            List[ObjT] | Dict[str, ObjT]: a list of attribute values (one per
                object), or a dict of such lists if `key` is a list of keys.

        Example:
            >>> import kwcoco
            >>> dset = kwcoco.CocoDataset.demo()
            >>> self = dset.annots()
            >>> self.lookup('id')
            >>> key = ['id']
            >>> default = None
            >>> self.lookup(key=['id', 'image_id'])
            >>> self.lookup(key='foo', default=None, keepid=True)
            >>> self.lookup(key=['foo'], default=None, keepid=True)
            >>> self.lookup(key=['id', 'image_id'], keepid=True)
        """
        # Note: while the old _lookup code was slightly faster than this, the
        # difference is extremely negligible (179us vs 178us).
        if ub.iterable(key):
            return {k: self.lookup(k, default, keepid) for k in key}
        else:
            return self.get(key, default=default, keepid=keepid)
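Because the method recurses when key is iterable, a single key returns a flat list while a list of keys returns a column-per-key dict, and keepid swaps the lists for id-keyed mappings. A short usage sketch continuing the doctest above (the keepid return shape is assumed from the Args description):

    import kwcoco
    annots = kwcoco.CocoDataset.demo().annots()
    annots.lookup('image_id')                # single key -> List of values
    annots.lookup(['id', 'image_id'])        # list of keys -> Dict[str, List]
    annots.lookup('image_id', keepid=True)   # assumed: Dict[annot id, value]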
Example #6
def multi_plot(xdata=None, ydata=None, xydata=None, **kwargs):
    r"""
    Plots multiple lines, bars, etc...

    One function call that concisely describes all of the most commonly
    used parameters needed when plotting a bar / line chart. This is especially
    useful when multiple plots are needed in the same domain.

    Args:
        xdata (List[ndarray] | Dict[str, ndarray] | ndarray):
            x-coordinate data common to all y-coordinate values or xdata for
            each line/bar in ydata.  Mutually exclusive with xydata.

        ydata (List[ndarray] | Dict[str, ndarray] | ndarray):
            y-coordinate values for each line/bar to plot. Can also
            be just a single ndarray of scalar values. Mutually exclusive with
            xydata.

        xydata (Dict[str, Tuple[ndarray, ndarray]]):
            mapping from labels to a tuple of xdata and ydata for each line.

        **kwargs:

            fnum (int):
                figure number to draw on

            pnum (Tuple[int, int, int]):
                plot number to draw on within the figure: e.g. (1, 1, 1)

            label (List|Dict): if you specified ydata as a List[ndarray]
                this is the label for each line in that list. Note this is
                unnecessary if you specify input as a dictionary mapping labels
                to lines.

            color (str|List|Dict): either a special color code, a single color,
                or a color for each item in ydata. In the latter case, this
                should be specified as either a list or a dict depending on how
                ydata was specified.

            marker (str|List|Dict): type of matplotlib marker to use at every
                data point. Can be specified for all lines jointly or for each
                line independently.

            transpose (bool, default=False): swaps x and y data.

            kind (str, default='plot'):
                The kind of plot. Can either be 'plot' or 'bar'.
                Extra kwargs are parsed depending on the kind:
                    if kind='plot': spread
                    if kind='bar': stacked, width

            Misc:
                use_legend (bool): ...
                legend_loc (str):
                    one of 'best', 'upper right', 'upper left', 'lower left',
                    'lower right', 'right', 'center left', 'center right',
                    'lower center', or 'upper center'.

            Layout:
                xlabel (str): label for x-axis
                ylabel (str): label for y-axis
                title (str): title for the axes
                figtitle (str): title for the figure

                xscale (str): can be one of [linear, log, logit, symlog]
                yscale (str): can be one of [linear, log, logit, symlog]

                xlim (Tuple[float, float]): low and high x-limit of axes
                ylim (Tuple[float, float]): low and high y-limit of axes
                xmin (float): low x-limit of axes, mutex with xlim
                xmax (float): high x-limit of axes, mutex with xlim
                ymin (float): low y-limit of axes, mutex with ylim
                ymax (float): high y-limit of axes, mutex with ylim

                titlesize (float): ...
                legendsize (float): ...
                labelsize (float): ...

            Grid:
                gridlinewidth (float): ...
                gridlinestyle (str): ...
            Ticks:
                num_xticks (int): number of x ticks
                num_yticks (int): number of y ticks
                tickwidth (float): ...
                ticklength (float): ...
                ticksize (float): ...
                xticklabels (list): list of x-tick labels, overrides num_xticks
                yticklabels (list): list of y-tick labels, overrides num_yticks
                xtick_rotation (float): xtick rotation in degrees
                ytick_rotation (float): ytick rotation in degrees
            Data:
                spread (List | Dict): Plots a spread around plot lines usually
                    indicating standard deviation

                markersize (float|List|Dict): marker size for all or each plot

                markeredgewidth (float|List|Dict): marker edge width for all or each plot

                linewidth (float|List|Dict): line width for all or each plot

                linestyle (str|List|Dict): line style for all or each plot

    Notes:
        any plot_kw key can be a scalar (corresponding to all ydatas),
        a list if ydata was specified as a list, or a dict if ydata was
        specified as a dict.

        plot_kw_keys = ['label', 'color', 'marker', 'markersize',
            'markeredgewidth', 'linewidth', 'linestyle']

    Returns:
        matplotlib.axes.Axes: ax : the axes that was drawn on

    References:
        matplotlib.org/examples/api/barchart_demo.html

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> # The new way to use multi_plot is to pass ydata as a dict of lists
        >>> ydata = {
        >>>     'spamΣ': [1, 1, 2, 3, 5, 8, 13],
        >>>     'eggs': [3, 3, 3, 3, 3, np.nan, np.nan],
        >>>     'jamµ': [5, 3, np.nan, 1, 2, np.nan, np.nan],
        >>>     'pram': [4, 2, np.nan, 0, 0, np.nan, 1],
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, title='ΣΣΣµµµ',
        >>>                      xlabel='\nfdsΣΣΣµµµ', linestyle='--')
        >>> kwplot.show_if_requested()

    Example:
        >>> # Old way to use multi_plot is a list of lists
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xdata = [1, 2, 3, 4, 5]
        >>> ydata_list = [[1, 2, 3, 4, 5], [3, 3, 3, 3, 3], [5, 4, np.nan, 2, 1], [4, 3, np.nan, 1, 0]]
        >>> kwargs = {'label': ['spamΣ', 'eggs', 'jamµ', 'pram'],  'linestyle': '-'}
        >>> #ax = multi_plot(xdata, ydata_list, title='$\phi_1(\\vec{x})$', xlabel='\nfds', **kwargs)
        >>> ax = multi_plot(xdata, ydata_list, title='ΣΣΣµµµ', xlabel='\nfdsΣΣΣµµµ', **kwargs)
        >>> kwplot.show_if_requested()

    Example:
        >>> # Simple way to use multi_plot is to pass xdata and ydata exactly
        >>> # like you would use plt.plot
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ax = multi_plot([1, 2, 3], [4, 5, 6], fnum=4, label='foo')
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xydata = {'a': ([0, 1, 2], [0, 1, 2]), 'b': ([0, 2, 4], [2, 1, 0])}
        >>> ax = kwplot.multi_plot(xydata=xydata, fnum=4)
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {'a': [0, 1, 2], 'b': [1, 2, 1], 'c': [4, 4, 4, 3, 2]}
        >>> kwargs = {
        >>>     'spread': {'a': [.2, .3, .1], 'b': .2},
        >>>     'xlim': (-1, 5),
        >>>     'xticklabels': ['foo', 'bar'],
        >>>     'xtick_rotation': 90,
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=4, **kwargs)
        >>> kwplot.show_if_requested()

    Ignore:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {
        >>>     str(i): np.random.rand(100) + i for i in range(30)
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=1, doclf=True)
        >>> kwplot.show_if_requested()
    """
    import matplotlib as mpl
    from matplotlib import pyplot as plt

    # Initial integration with mpl rcParams standards
    mplrc = mpl.rcParams
    # mplrc.update({
    #     # 'legend.fontsize': custom_figure.LEGEND_SIZE,
    #     # 'legend.framealpha':
    #     # 'axes.titlesize': custom_figure.TITLE_SIZE,
    #     # 'axes.labelsize': custom_figure.LABEL_SIZE,
    #     # 'legend.facecolor': 'w',
    #     # 'font.family': 'sans-serif',
    #     # 'xtick.labelsize': custom_figure.TICK_SIZE,
    #     # 'ytick.labelsize': custom_figure.TICK_SIZE,
    # })
    if 'rcParams' in kwargs:
        mplrc = mplrc.copy()
        mplrc.update(kwargs['rcParams'])

    if xydata is not None:
        if xdata is not None or ydata is not None:
            raise ValueError('Cannot specify xydata with xdata or ydata')
        if isinstance(xydata, dict):
            xdata = ub.odict((k, np.array(xy[0])) for k, xy in xydata.items())
            ydata = ub.odict((k, np.array(xy[1])) for k, xy in xydata.items())
        else:
            raise ValueError('Only supports xydata as Dict at the moment')

    if 'label' in kwargs and 'label_list' in kwargs:
        raise ValueError('Specify either label or label_list')

    if isinstance(ydata, dict):
        # Case where ydata is a dictionary
        if isinstance(xdata, six.string_types):
            # Special-er case where xdata is specified in ydata
            xkey = xdata
            ykeys = set(ydata.keys()) - {xkey}
            xdata = ydata[xkey]
        else:
            ykeys = list(ydata.keys())
        # Normalize input into ydata_list
        ydata_list = list(ub.take(ydata, ykeys))

        default_label_list = kwargs.pop('label', ykeys)
        kwargs['label_list'] = kwargs.get('label_list', default_label_list)
    else:
        # ydata should be a List[ndarray] or an ndarray
        ydata_list = ydata
        ykeys = None

    # allow ydata_list to be passed without a container
    if is_list_of_scalars(ydata_list):
        ydata_list = [np.array(ydata_list)]

    if xdata is None:
        xdata = list(range(max(map(len, ydata_list))))

    num_lines = len(ydata_list)

    # Transform xdata into xdata_list
    if isinstance(xdata, dict):
        xdata_list = [np.array(xdata[k], copy=True) for k in ykeys]
    elif is_list_of_lists(xdata):
        xdata_list = [np.array(xd, copy=True) for xd in xdata]
    else:
        xdata_list = [np.array(xdata, copy=True)] * num_lines

    fnum = mpl_core.ensure_fnum(kwargs.get('fnum', None))
    pnum = kwargs.get('pnum', None)
    kind = kwargs.get('kind', 'plot')
    transpose = kwargs.get('transpose', False)

    def parsekw_list(key,
                     kwargs,
                     num_lines=num_lines,
                     ykeys=ykeys,
                     default=ub.NoParam):
        """
        Return properties that correspond with ydata_list.

        Searches kwargs for several keys based on the base key and finds either
        a scalar, list, or dict and coerces this into a list of properties that
        corresponds with the ydata_list.
        """
        if key in kwargs:
            val_list = kwargs[key]
        elif key + '_list' in kwargs:
            # warnings.warn('*_list is deprecated, just use kwarg {}'.format(key))
            val_list = kwargs[key + '_list']
        elif key + 's' in kwargs:
            # hack, multiple ways to do something
            warnings.warn('*s is deprecated, just use kwarg {}'.format(key))
            val_list = kwargs[key + 's']
        else:
            val_list = None

        if val_list is not None:
            if isinstance(val_list, dict):
                # Extract properly ordered dictionary values
                if ykeys is None:
                    raise ValueError(
                        'Kwarg {!r} was a dict, but ydata was not'.format(key))
                else:
                    if default is ub.NoParam:
                        val_list = [val_list[key] for key in ykeys]
                    else:
                        val_list = [
                            val_list.get(key, default) for key in ykeys
                        ]

            if not isinstance(val_list, list):
                # Coerce a scalar value into a list
                val_list = [val_list] * num_lines
        return val_list

    if kind == 'plot':
        if 'marker' not in kwargs:
            # kwargs['marker'] = mplrc['lines.marker']
            kwargs['marker'] = 'distinct'
            # kwargs['marker'] = 'cycle'

        if isinstance(kwargs['marker'], six.string_types):
            if kwargs['marker'] == 'distinct':
                kwargs['marker'] = mpl_core.distinct_markers(num_lines)
            elif kwargs['marker'] == 'cycle':
                # Note the length of marker and linestyle cycles should be
                # relatively prime.
                # https://matplotlib.org/api/markers_api.html
                marker_cycle = ['.', '*', 'x']
                kwargs['marker'] = [
                    marker_cycle[i % len(marker_cycle)]
                    for i in range(num_lines)
                ]
            # else:
            #     raise KeyError(kwargs['marker'])

        if 'linestyle' not in kwargs:
            # kwargs['linestyle'] = 'distinct'
            kwargs['linestyle'] = mplrc['lines.linestyle']
            # kwargs['linestyle'] = 'cycle'

        if isinstance(kwargs['linestyle'], six.string_types):
            if kwargs['linestyle'] == 'cycle':
                # https://matplotlib.org/gallery/lines_bars_and_markers/line_styles_reference.html
                linestyle_cycle = ['solid', 'dashed', 'dashdot', 'dotted']
                kwargs['linestyle'] = [
                    linestyle_cycle[i % len(linestyle_cycle)]
                    for i in range(num_lines)
                ]

    if 'color' not in kwargs:
        # kwargs['color'] = 'jet'
        # kwargs['color'] = 'gist_rainbow'
        kwargs['color'] = 'distinct'

    if isinstance(kwargs['color'], six.string_types):
        if kwargs['color'] == 'distinct':
            kwargs['color'] = mpl_core.distinct_colors(num_lines, randomize=0)
        else:
            cm = plt.get_cmap(kwargs['color'])
            kwargs['color'] = [cm(i / num_lines) for i in range(num_lines)]

    # Parse out arguments to ax.plot
    plot_kw_keys = [
        'label', 'color', 'marker', 'markersize', 'markeredgewidth',
        'linewidth', 'linestyle', 'alpha'
    ]
    # hackish / extra args that dont directly get passed to plt.plot
    extra_plot_kw_keys = ['spread_alpha', 'autolabel', 'edgecolor', 'fill']
    plot_kw_keys += extra_plot_kw_keys
    plot_kw_vals = [parsekw_list(key, kwargs) for key in plot_kw_keys]
    plot_list_kw = {key: vals
                    for key, vals in zip(plot_kw_keys, plot_kw_vals)
                    if vals is not None}

    if kind == 'plot':
        if 'spread_alpha' not in plot_list_kw:
            plot_list_kw['spread_alpha'] = [.2] * num_lines

    if kind == 'bar':
        # Remove non-bar kwargs
        for key in [
                'markeredgewidth', 'linewidth', 'marker', 'markersize',
                'linestyle'
        ]:
            plot_list_kw.pop(key, None)

        stacked = kwargs.get('stacked', False)
        width_key = 'height' if transpose else 'width'
        if 'width_list' in kwargs:
            plot_list_kw[width_key] = kwargs['width_list']
        else:
            width = kwargs.get('width', .9)
            # if width is None:
            #     # HACK: need variable width
            #     # width = np.mean(np.diff(xdata_list[0]))
            #     width = .9
            if not stacked:
                width /= num_lines
            #plot_list_kw['orientation'] = ['horizontal'] * num_lines
            plot_list_kw[width_key] = [width] * num_lines

    spread_list = parsekw_list('spread', kwargs, default=None)

    # nest into a list of dicts for each line in the multiplot
    valid_keys = list(set(plot_list_kw.keys()) - set(extra_plot_kw_keys))
    valid_vals = list(ub.dict_take(plot_list_kw, valid_keys))
    plot_kw_list = [dict(zip(valid_keys, vals)) for vals in zip(*valid_vals)]

    extra_kw_keys = [key for key in extra_plot_kw_keys if key in plot_list_kw]
    extra_kw_vals = list(ub.dict_take(plot_list_kw, extra_kw_keys))
    extra_kw_list = [
        dict(zip(extra_kw_keys, vals)) for vals in zip(*extra_kw_vals)
    ]

    # Get passed in axes or setup a new figure
    ax = kwargs.get('ax', None)
    if ax is None:
        # NOTE: This is slow, can we speed it up?
        doclf = kwargs.get('doclf', False)
        fig = mpl_core.figure(fnum=fnum, pnum=pnum, docla=False, doclf=doclf)
        ax = fig.gca()
    else:
        plt.sca(ax)
        fig = ax.figure

    # +---------------
    # Draw plot lines
    ydata_list = [np.array(ydata) for ydata in ydata_list]

    if transpose:
        if kind == 'bar':
            plot_func = ax.barh
        elif kind == 'plot':

            def plot_func(_x, _y, **kw):
                return ax.plot(_y, _x, **kw)
    else:
        plot_func = getattr(ax, kind)  # usually ax.plot

    if len(ydata_list) > 0:
        # raise ValueError('no ydata')
        _iter = enumerate(
            zip_longest(xdata_list, ydata_list, plot_kw_list, extra_kw_list))
        for count, (_xdata, _ydata, plot_kw, extra_kw) in _iter:
            _ydata = _ydata[0:len(_xdata)]
            _xdata = _xdata[0:len(_ydata)]
            ymask = np.isfinite(_ydata)
            ydata_ = _ydata.compress(ymask)
            xdata_ = _xdata.compress(ymask)
            if kind == 'bar':
                if stacked:
                    # Plot bars on top of each other; no x offset needed
                    pass
                else:
                    # Plot bars side by side
                    baseoffset = (width * num_lines) / 2
                    lineoffset = (width * count)
                    offset = baseoffset - lineoffset  # FIXME for more histogram bars
                    xdata_ = xdata_ - offset
                # width_key = 'height' if transpose else 'width'
                # plot_kw[width_key] = np.diff(xdata)
            objs = plot_func(xdata_, ydata_, **plot_kw)

            if kind == 'bar':
                if extra_kw is not None and 'edgecolor' in extra_kw:
                    for rect in objs:
                        rect.set_edgecolor(extra_kw['edgecolor'])
                if extra_kw is not None and extra_kw.get('autolabel', False):
                    # FIXME: probably a more canonical way to include bar
                    # autolabeling with transpose support, but this is a hack that
                    # works for now
                    for rect in objs:
                        if transpose:
                            numlbl = width = rect.get_width()
                            xpos = width + (
                                (_xdata.max() - _xdata.min()) * .005)
                            ypos = rect.get_y() + rect.get_height() / 2.
                            ha, va = 'left', 'center'
                        else:
                            numlbl = height = rect.get_height()
                            xpos = rect.get_x() + rect.get_width() / 2.
                            ypos = 1.05 * height
                            ha, va = 'center', 'bottom'
                        barlbl = '%.3f' % (numlbl, )
                        ax.text(xpos, ypos, barlbl, ha=ha, va=va)

            if kind == 'plot' and extra_kw.get('fill', False):
                ax.fill_between(_xdata,
                                ydata_,
                                alpha=plot_kw.get('alpha', 1.0),
                                color=plot_kw.get('color',
                                                  None))  # , zorder=0)

            if spread_list is not None:
                # Plots a spread around plot lines usually indicating standard
                # deviation
                _xdata = np.array(_xdata)
                _spread = spread_list[count]
                if _spread is not None:
                    if not ub.iterable(_spread):
                        _spread = [_spread] * len(ydata_)
                    ydata_ave = np.array(ydata_)
                    y_data_dev = np.array(_spread)
                    y_data_max = ydata_ave + y_data_dev
                    y_data_min = ydata_ave - y_data_dev
                    ax = plt.gca()
                    spread_alpha = extra_kw['spread_alpha']
                    ax.fill_between(_xdata,
                                    y_data_min,
                                    y_data_max,
                                    alpha=spread_alpha,
                                    color=plot_kw.get('color',
                                                      None))  # , zorder=0)
        ydata = _ydata  # HACK
        xdata = _xdata  # HACK
    # L________________

    #max_y = max(np.max(y_data), max_y)
    #min_y = np.min(y_data) if min_y is None else min(np.min(y_data), min_y)

    if transpose:
        #xdata_list = ydata_list
        ydata = xdata

        # Hack / Fix any transpose issues
        def transpose_key(key):
            if key.startswith('x'):
                return 'y' + key[1:]
            elif key.startswith('y'):
                return 'x' + key[1:]
            elif key.startswith('num_x'):
                # hackier, fixme to use regex or something
                return 'num_y' + key[5:]
            elif key.startswith('num_y'):
                # hackier, fixme to use regex or something
                return 'num_x' + key[5:]
            else:
                return key

        kwargs = {transpose_key(key): val for key, val in kwargs.items()}

    # Setup axes labeling
    title = kwargs.get('title', None)
    xlabel = kwargs.get('xlabel', '')
    ylabel = kwargs.get('ylabel', '')

    def none_or_unicode(text):
        return None if text is None else ub.ensure_unicode(text)

    xlabel = none_or_unicode(xlabel)
    ylabel = none_or_unicode(ylabel)
    title = none_or_unicode(title)

    titlesize = kwargs.get('titlesize', mplrc['axes.titlesize'])
    labelsize = kwargs.get('labelsize', mplrc['axes.labelsize'])
    legendsize = kwargs.get('legendsize', mplrc['legend.fontsize'])
    xticksize = kwargs.get('ticksize', mplrc['xtick.labelsize'])
    yticksize = kwargs.get('ticksize', mplrc['ytick.labelsize'])
    family = kwargs.get('fontfamily', mplrc['font.family'])

    tickformat = kwargs.get('tickformat', None)
    ytickformat = kwargs.get('ytickformat', tickformat)
    xtickformat = kwargs.get('xtickformat', tickformat)

    # 'DejaVu Sans','Verdana', 'Arial'
    weight = kwargs.get('fontweight', None)
    if weight is None:
        weight = 'normal'

    labelkw = {
        'fontproperties':
        mpl.font_manager.FontProperties(weight=weight,
                                        family=family,
                                        size=labelsize)
    }
    ax.set_xlabel(xlabel, **labelkw)
    ax.set_ylabel(ylabel, **labelkw)

    tick_fontprop = mpl.font_manager.FontProperties(family=family,
                                                    weight=weight)

    if tick_fontprop is not None:
        # NOTE: This is slow, can we speed it up?
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_fontproperties(tick_fontprop)
        for ticklabel in ax.get_yticklabels():
            ticklabel.set_fontproperties(tick_fontprop)
    if xticksize is not None:
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_fontsize(xticksize)
    if yticksize is not None:
        for ticklabel in ax.get_yticklabels():
            ticklabel.set_fontsize(yticksize)

    if xtickformat is not None:
        # mpl.ticker.StrMethodFormatter  # new style
        # mpl.ticker.FormatStrFormatter  # old style
        ax.xaxis.set_major_formatter(
            mpl.ticker.FormatStrFormatter(xtickformat))
    if ytickformat is not None:
        ax.yaxis.set_major_formatter(
            mpl.ticker.FormatStrFormatter(ytickformat))

    xtick_kw = ytick_kw = {
        'width': kwargs.get('tickwidth', None),
        'length': kwargs.get('ticklength', None),
    }
    xtick_kw = {k: v for k, v in xtick_kw.items() if v is not None}
    ytick_kw = {k: v for k, v in ytick_kw.items() if v is not None}
    ax.xaxis.set_tick_params(**xtick_kw)
    ax.yaxis.set_tick_params(**ytick_kw)

    #ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%d'))

    # Setup axes limits
    if 'xlim' in kwargs:
        xlim = kwargs['xlim']
        if xlim is not None:
            if 'xmin' not in kwargs and 'xmax' not in kwargs:
                kwargs['xmin'] = xlim[0]
                kwargs['xmax'] = xlim[1]
            else:
                raise ValueError('use xmax, xmin instead of xlim')
    if 'ylim' in kwargs:
        ylim = kwargs['ylim']
        if ylim is not None:
            if 'ymin' not in kwargs and 'ymax' not in kwargs:
                kwargs['ymin'] = ylim[0]
                kwargs['ymax'] = ylim[1]
            else:
                raise ValueError('use ymax, ymin instead of ylim')

    xmin = kwargs.get('xmin', ax.get_xlim()[0])
    xmax = kwargs.get('xmax', ax.get_xlim()[1])
    ymin = kwargs.get('ymin', ax.get_ylim()[0])
    ymax = kwargs.get('ymax', ax.get_ylim()[1])

    text_type = six.text_type

    if text_type(xmax) == 'data':
        xmax = max([xd.max() for xd in xdata_list])
    if text_type(xmin) == 'data':
        xmin = min([xd.min() for xd in xdata_list])

    # Setup axes ticks
    num_xticks = kwargs.get('num_xticks', None)
    num_yticks = kwargs.get('num_yticks', None)

    if num_xticks is not None:
        if np.asarray(xdata).dtype.kind == 'i':
            xticks = np.linspace(np.ceil(xmin), np.floor(xmax),
                                 num_xticks).astype(np.int32)
        else:
            xticks = np.linspace((xmin), (xmax), num_xticks)
        ax.set_xticks(xticks)
    if num_yticks is not None:
        if np.asarray(ydata).dtype.kind == 'i':
            yticks = np.linspace(np.ceil(ymin), np.floor(ymax),
                                 num_yticks).astype(np.int32)
        else:
            yticks = np.linspace((ymin), (ymax), num_yticks)
        ax.set_yticks(yticks)

    force_xticks = kwargs.get('force_xticks', None)
    if force_xticks is not None:
        xticks = np.array(sorted(ax.get_xticks().tolist() + force_xticks))
        ax.set_xticks(xticks)

    yticklabels = kwargs.get('yticklabels', None)
    if yticklabels is not None:
        # Hack ONLY WORKS WHEN TRANSPOSE = True
        # Overrides num_yticks
        missing_labels = max(len(ydata) - len(yticklabels), 0)
        yticklabels_ = yticklabels + [''] * missing_labels
        ax.set_yticks(ydata)
        ax.set_yticklabels(yticklabels_)

    xticklabels = kwargs.get('xticklabels', None)
    if xticklabels is not None:
        # Overrides num_xticks
        missing_labels = max(len(xdata) - len(xticklabels), 0)
        xticklabels_ = xticklabels + [''] * missing_labels
        ax.set_xticks(xdata)
        ax.set_xticklabels(xticklabels_)

    xticks = kwargs.get('xticks', None)
    if xticks is not None:
        ax.set_xticks(xticks)

    yticks = kwargs.get('yticks', None)
    if yticks is not None:
        ax.set_yticks(yticks)

    xtick_rotation = kwargs.get('xtick_rotation', None)
    if xtick_rotation is not None:
        for lbl in ax.get_xticklabels():
            lbl.set_rotation(xtick_rotation)
    ytick_rotation = kwargs.get('ytick_rotation', None)
    if ytick_rotation is not None:
        for lbl in ax.get_yticklabels():
            lbl.set_rotation(ytick_rotation)

    # Axis padding
    xpad = kwargs.get('xpad', None)
    ypad = kwargs.get('ypad', None)
    xpad_factor = kwargs.get('xpad_factor', None)
    ypad_factor = kwargs.get('ypad_factor', None)
    if xpad is None and xpad_factor is not None:
        xpad = (xmax - xmin) * xpad_factor
    if ypad is None and ypad_factor is not None:
        ypad = (ymax - ymin) * ypad_factor
    xpad = 0 if xpad is None else xpad
    ypad = 0 if ypad is None else ypad
    ypad_high = kwargs.get('ypad_high', ypad)
    ypad_low = kwargs.get('ypad_low', ypad)
    xpad_high = kwargs.get('xpad_high', xpad)
    xpad_low = kwargs.get('xpad_low', xpad)
    xmin, xmax = (xmin - xpad_low), (xmax + xpad_high)
    ymin, ymax = (ymin - ypad_low), (ymax + ypad_high)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    xscale = kwargs.get('xscale', None)
    yscale = kwargs.get('yscale', None)
    if yscale is not None:
        ax.set_yscale(yscale)
    if xscale is not None:
        ax.set_xscale(xscale)

    gridlinestyle = kwargs.get('gridlinestyle', None)
    gridlinewidth = kwargs.get('gridlinewidth', None)
    gridlines = ax.get_xgridlines() + ax.get_ygridlines()
    if gridlinestyle:
        for line in gridlines:
            line.set_linestyle(gridlinestyle)
    if gridlinewidth:
        for line in gridlines:
            line.set_linewidth(gridlinewidth)

    # Setup title
    if title is not None:
        titlekw = {
            'fontproperties':
            mpl.font_manager.FontProperties(family=family,
                                            weight=weight,
                                            size=titlesize)
        }
        ax.set_title(title, **titlekw)

    use_legend = kwargs.get('use_legend', 'label' in valid_keys)
    legend_loc = kwargs.get('legend_loc', mplrc['legend.loc'])
    legend_alpha = kwargs.get('legend_alpha', mplrc['legend.framealpha'])
    if use_legend:
        legendkw = {
            'alpha':
            legend_alpha,
            'fontproperties':
            mpl.font_manager.FontProperties(family=family,
                                            weight=weight,
                                            size=legendsize)
        }
        mpl_core.legend(loc=legend_loc, ax=ax, **legendkw)

    figtitle = kwargs.get('figtitle', None)

    if figtitle is not None:
        # mplrc['figure.titlesize'] TODO?
        mpl_core.set_figtitle(figtitle,
                              fontfamily=family,
                              fontweight=weight,
                              size=kwargs.get('figtitlesize'))

    # TODO: return better info
    return ax
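The heart of the kwarg handling above is parsekw_list, which normalizes a scalar, list, or dict property into one value per plotted line. A standalone sketch of that normalization rule (normalize_per_line is a hypothetical helper restating the closure's logic):

    def normalize_per_line(val, num_lines, ykeys=None, default=None):
        # scalar -> repeated; dict -> ordered by ykeys; list -> unchanged
        if isinstance(val, dict):
            if ykeys is None:
                raise ValueError('dict-valued kwargs require dict ydata')
            val = [val.get(k, default) for k in ykeys]
        if not isinstance(val, list):
            val = [val] * num_lines
        return val

    normalize_per_line('--', 3)                            # ['--', '--', '--']
    normalize_per_line({'a': 'red'}, 2, ykeys=['a', 'b'])  # ['red', None]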
Example #7
def non_max_supression(tlbr,
                       scores,
                       thresh,
                       bias=0.0,
                       classes=None,
                       impl='auto',
                       device_id=None):
    """
    Non-Maximum Suppression - remove redundant bounding boxes

    Args:
        tlbr (ndarray[float32]): Nx4 boxes in tlbr format
        scores (ndarray[float32]): score for each bbox
        thresh (float): iou threshold.
            Boxes are removed if they overlap greater than this threshold
            (i.e. Boxes are removed if iou > threshold).
            thresh = 0 is the most strict, resulting in the fewest boxes, and 1
            is the most permissive, resulting in the most.
        bias (float): bias for iou computation either 0 or 1
        classes (ndarray[int64] or None): integer classes.
            If specified, NMS is done on a per-class basis.
        impl (str): implementation can be auto, python, cython_cpu, or gpu
        device_id (int): used if impl is gpu, device id to work on. If not
            specified `torch.cuda.current_device()` is used.

    Notes:
        Using impl='cython_gpu' may result in a CUDA memory error that is not exposed
        to the Python process. In other words, your program will hard crash if
        impl='cython_gpu' and you feed it too many bounding boxes. Ideally this will
        be fixed in the future.

    References:
        https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/cython_nms.pyx
        https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
        https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx <- TODO

    CommandLine:
        xdoctest -m ~/code/kwimage/kwimage/algo/algo_nms.py non_max_supression

    Example:
        >>> from kwimage.algo.algo_nms import *
        >>> from kwimage.algo.algo_nms import _impls
        >>> tlbr = np.array([
        >>>     [0, 0, 100, 100],
        >>>     [100, 100, 10, 10],
        >>>     [10, 10, 100, 100],
        >>>     [50, 50, 100, 100],
        >>> ], dtype=np.float32)
        >>> scores = np.array([.1, .5, .9, .1])
        >>> keep = non_max_supression(tlbr, scores, thresh=0.5, impl='numpy')
        >>> print('keep = {!r}'.format(keep))
        >>> assert keep == [2, 1, 3]
        >>> thresh = 0.0
        >>> non_max_supression(tlbr, scores, thresh, impl='numpy')
        >>> if 'numpy' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='numpy')
        >>>     assert list(keep) == [2, 1]
        >>> if 'cython_cpu' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='cython_cpu')
        >>>     assert list(keep) == [2, 1]
        >>> if 'cython_gpu' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='cython_gpu')
        >>>     assert list(keep) == [2, 1]
        >>> if 'torch' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='torch')
        >>>     assert set(keep.tolist()) == {2, 1}
        >>> if 'torchvision' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='torchvision')  # note torchvision has no bias
        >>>     assert list(keep) == [2]
        >>> thresh = 1.0
        >>> if 'numpy' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='numpy')
        >>>     assert list(keep) == [2, 1, 3, 0]
        >>> if 'cython_cpu' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='cython_cpu')
        >>>     assert list(keep) == [2, 1, 3, 0]
        >>> if 'cython_gpu' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='cython_gpu')
        >>>     assert list(keep) == [2, 1, 3, 0]
        >>> if 'torch' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='torch')
        >>>     assert set(keep.tolist()) == {2, 1, 3, 0}
        >>> if 'torchvision' in available_nms_impls():
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl='torchvision')  # note torchvision has no bias
        >>>     assert set(kwarray.ArrayAPI.tolist(keep)) == {2, 1, 3, 0}

    Example:
        >>> import ubelt as ub
        >>> tlbr = np.array([
        >>>     [0, 0, 100, 100],
        >>>     [100, 100, 10, 10],
        >>>     [10, 10, 100, 100],
        >>>     [50, 50, 100, 100],
        >>>     [100, 100, 150, 101],
        >>>     [120, 100, 180, 101],
        >>>     [150, 100, 200, 101],
        >>> ], dtype=np.float32)
        >>> scores = np.linspace(0, 1, len(tlbr))
        >>> thresh = .2
        >>> solutions = {}
        >>> if not _impls._funcs:
        >>>     _impls._lazy_init()
        >>> for impl in _impls._funcs:
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl=impl)
        >>>     solutions[impl] = sorted(keep)
        >>> assert 'numpy' in solutions
        >>> print('solutions = {}'.format(ub.repr2(solutions, nl=1)))
        >>> assert ub.allsame(solutions.values())

    Example:
        >>> import ubelt as ub
        >>> # Check that zero-area boxes are ok
        >>> tlbr = np.array([
        >>>     [0, 0, 0, 0],
        >>>     [0, 0, 0, 0],
        >>>     [10, 10, 10, 10],
        >>> ], dtype=np.float32)
        >>> scores = np.array([1, 2, 3], dtype=np.float32)
        >>> thresh = .2
        >>> solutions = {}
        >>> if not _impls._funcs:
        >>>     _impls._lazy_init()
        >>> for impl in _impls._funcs:
        >>>     keep = non_max_supression(tlbr, scores, thresh, impl=impl)
        >>>     solutions[impl] = sorted(keep)
        >>> assert 'numpy' in solutions
        >>> print('solutions = {}'.format(ub.repr2(solutions, nl=1)))
        >>> assert ub.allsame(solutions.values())
    """

    if impl == 'cpu':
        import warnings
        warnings.warn('impl="cpu" is deprecated use impl="cython_cpu" instead',
                      DeprecationWarning)
        impl = 'cython_cpu'
    elif impl == 'gpu':
        import warnings
        warnings.warn('impl="gpu" is deprecated use impl="cython_gpu" instead',
                      DeprecationWarning)
        impl = 'cython_gpu'
    elif impl == 'py':
        import warnings
        warnings.warn('impl="py" is deprecated use impl="numpy" instead',
                      DeprecationWarning)
        impl = 'numpy'

    if not _impls._funcs:
        _impls._lazy_init()

    if tlbr.shape[0] == 0:
        return []

    if impl == 'auto':
        is_tensor = torch is not None and torch.is_tensor(tlbr)
        num = len(tlbr)
        if is_tensor:
            if tlbr.device.type == 'cuda':
                code = 'tensor0'
            else:
                code = 'tensor'
        else:
            code = 'ndarray'
        valid = _impls._valid
        impl = _heuristic_auto_nms_impl(code, num, valid)
        # print('impl._valid = {!r}'.format(_impls._valid))
        # print('impl = {!r}'.format(impl))

    elif ub.iterable(impl):
        # if impl is iterable, it is a preference order
        found = False
        for item in impl:
            if item in _impls._funcs:
                impl = item
                found = True
                break
        if not found:
            raise KeyError('Unknown impls={}'.format(impl))

    if classes is not None:
        keep = []
        for idxs in ub.group_items(range(len(classes)), classes).values():
            # cls_tlbr = tlbr.take(idxs, axis=0)
            # cls_scores = scores.take(idxs, axis=0)
            cls_tlbr = tlbr[idxs]
            cls_scores = scores[idxs]
            cls_keep = non_max_supression(cls_tlbr,
                                          cls_scores,
                                          thresh=thresh,
                                          bias=bias,
                                          impl=impl)
            keep.extend(list(ub.take(idxs, cls_keep)))
        return keep
    else:

        if impl == 'numpy':
            api = kwarray.ArrayAPI.coerce(tlbr)
            tlbr = api.numpy(tlbr)
            scores = api.numpy(scores)
            func = _impls._funcs['numpy']
            keep = func(tlbr, scores, thresh, bias=float(bias))
        elif impl == 'torch' or impl == 'torchvision':
            api = kwarray.ArrayAPI.coerce(tlbr)
            tlbr = api.tensor(tlbr).float()
            scores = api.tensor(scores).float()
            # Default output of torch impl is a mask
            if impl == 'torchvision':
                # if bias != 1:
                #     warnings.warn('torchvision only supports bias==1')
                func = _impls._funcs['torchvision']
                # Torchvision returns indices
                keep = func(tlbr, scores, iou_threshold=thresh)
            else:
                func = _impls._funcs['torch']
                flags = func(tlbr, scores, thresh=thresh, bias=float(bias))
                keep = torch.nonzero(flags).view(-1)

            # Ensure that the input type is the same as the output type
            keep = api.numpy(keep)
        else:
            # TODO: it would be nice to be able to pass torch tensors here
            nms = _impls._funcs[impl]
            tlbr = kwarray.ArrayAPI.numpy(tlbr)
            scores = kwarray.ArrayAPI.numpy(scores)
            tlbr = tlbr.astype(np.float32)
            scores = scores.astype(np.float32)
            if impl == 'cython_gpu':
                # TODO: if the data is already on a torch GPU can we just
                # use it?
                # HACK: we should parameterize which device is used
                if device_id is None:
                    device_id = torch.cuda.current_device()
                keep = nms(tlbr,
                           scores,
                           float(thresh),
                           bias=float(bias),
                           device_id=device_id)
            elif impl == 'cython_cpu':
                keep = nms(tlbr, scores, float(thresh), bias=float(bias))
            else:
                raise KeyError(impl)
        return keep
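For reference, the greedy procedure that each backend implements is short. A hedged pure-NumPy sketch of textbook greedy NMS with the same bias convention (not kwimage's 'numpy' backend verbatim):

    import numpy as np

    def greedy_nms(tlbr, scores, thresh, bias=0.0):
        x1, y1, x2, y2 = tlbr.T
        areas = (x2 - x1 + bias) * (y2 - y1 + bias)
        order = scores.argsort()[::-1]  # highest score first
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(int(i))
            # Intersection of the top box with all remaining candidates
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            inter = (np.maximum(0, xx2 - xx1 + bias) *
                     np.maximum(0, yy2 - yy1 + bias))
            # epsilon guards division by zero for zero-area boxes
            iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-12)
            order = order[1:][iou <= thresh]  # suppress high-overlap boxes
        return keep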
Example #8
def daq_spatial_nms(tlbr,
                    scores,
                    diameter,
                    thresh,
                    max_depth=6,
                    stop_size=2048,
                    recsize=2048,
                    impl='auto',
                    device_id=None):
    """
    Divide and conquer speedup for non-max-suppression when bboxes
    have a known max size

    Args:
        tlbr (ndarray): boxes in (tlx, tly, brx, bry) format

        scores (ndarray): scores of each box

        diameter (int or Tuple[int, int]): Distance from split point to
            consider rectification. If specified as an integer, then the number
            is used for both height and width. If specified as a tuple, then
            dims are assumed to be in [height, width] format.

        thresh (float): iou threshold. Boxes are removed if they overlap
            greater than this threshold. 0 is the most strict, resulting in the
            fewest boxes, and 1 is the most permissive resulting in the most.

        max_depth (int): maximum number of times we can divide and conquer

        stop_size (int): number of boxes that triggers full NMS computation

        recsize (int): number of boxes that triggers full NMS recombination

        impl (str): algorithm to use

    LookInfo:
        # Didn't read yet but it seems similar
        http://www.cyberneum.de/fileadmin/user_upload/files/publications/CVPR2010-Lampert_[0].pdf

        https://www.researchgate.net/publication/220929789_Efficient_Non-Maximum_Suppression

        # This seems very similar
        https://projet.liris.cnrs.fr/m2disco/pub/Congres/2006-ICPR/DATA/C03_0406.PDF

    Example:
        >>> import kwimage
        >>> # Make a bunch of boxes with the same width and height
        >>> #boxes = kwimage.Boxes.random(230397, scale=1000, format='cxywh')
        >>> boxes = kwimage.Boxes.random(237, scale=1000, format='cxywh')
        >>> boxes.data.T[2] = 10
        >>> boxes.data.T[3] = 10
        >>> #
        >>> tlbr = boxes.to_tlbr().data.astype(np.float32)
        >>> scores = np.arange(0, len(tlbr)).astype(np.float32)
        >>> #
        >>> n_megabytes = (tlbr.size * tlbr.dtype.itemsize) / (2 ** 20)
        >>> print('n_megabytes = {!r}'.format(n_megabytes))
        >>> #
        >>> thresh = iou_thresh = 0.01
        >>> impl = 'auto'
        >>> max_depth = 20
        >>> diameter = 10
        >>> stop_size = 2000
        >>> recsize = 500
        >>> #
        >>> import ubelt as ub
        >>> #
        >>> with ub.Timer(label='daq'):
        >>>     keep1 = daq_spatial_nms(tlbr, scores,
        >>>         diameter=diameter, thresh=thresh, max_depth=max_depth,
        >>>         stop_size=stop_size, recsize=recsize, impl=impl)
        >>> #
        >>> with ub.Timer(label='full'):
        >>>     keep2 = non_max_supression(tlbr, scores,
        >>>         thresh=thresh, impl=impl)
        >>> #
        >>> # Due to the greedy nature of the algorithm, there will be slight
        >>> # differences in results, but they will be mostly similar.
        >>> similarity = len(set(keep1) & set(keep2)) / len(set(keep1) | set(keep2))
        >>> print('similarity = {!r}'.format(similarity))
    """
    def _rectify(tlbr, both_keep, needs_rectify):
        if len(needs_rectify) == 0:
            keep = sorted(both_keep)
        else:
            nr_arr = np.array(sorted(needs_rectify))
            nr = needs_rectify
            bk = set(both_keep)
            rectified_keep = non_max_supression(tlbr[nr_arr],
                                                scores[nr_arr],
                                                thresh=thresh,
                                                impl=impl,
                                                device_id=device_id)
            rk = set(nr_arr[rectified_keep])
            keep = sorted((bk - nr) | rk)
        return keep

    def _recurse(tlbr, scores, dim, depth, diameter_wh):
        """
        Args:
            dim (int): flips between 0 and 1
            depth (int): recursion depth
        """
        # print('recurse')
        n_boxes = len(tlbr)
        if depth >= max_depth or n_boxes < stop_size:
            # print('n_boxes = {!r}'.format(n_boxes))
            # print('depth = {!r}'.format(depth))
            # print('stop')
            keep = non_max_supression(tlbr, scores, thresh=thresh, impl=impl)
            both_keep = sorted(keep)
            needs_rectify = set()
        else:
            # Break up the NMS into two subproblems.
            middle = np.median(tlbr.T[dim])
            left_flags = tlbr.T[dim] < middle
            right_flags = ~left_flags

            left_idxs = np.where(left_flags)[0]
            right_idxs = np.where(right_flags)[0]

            left_scores = scores[left_idxs]
            left_tlbr = tlbr[left_idxs]

            right_scores = scores[right_idxs]
            right_tlbr = tlbr[right_idxs]

            next_depth = depth + 1
            next_dim = 1 - dim

            # Solve each subproblem
            left_keep_, lrec_ = _recurse(left_tlbr,
                                         left_scores,
                                         depth=next_depth,
                                         dim=next_dim,
                                         diameter_wh=diameter_wh)

            right_keep_, rrec_ = _recurse(right_tlbr,
                                          right_scores,
                                          depth=next_depth,
                                          dim=next_dim,
                                          diameter_wh=diameter_wh)

            # Recombine the results (note that because we have a diameter_wh,
            # we have to check less results)
            rrec = set(right_idxs[sorted(rrec_)])
            lrec = set(left_idxs[sorted(lrec_)])

            left_keep = left_idxs[left_keep_]
            right_keep = right_idxs[right_keep_]

            both_keep = np.hstack([left_keep, right_keep])
            both_keep.sort()

            dist_to_middle = np.abs(tlbr[both_keep].T[dim] - middle)

            # Find all surviving boxes that are close to the midpoint.  We will
            # need to recheck these because they may overlap, but they also may
            # have been split into different subproblems.
            rectify_flags = dist_to_middle < diameter_wh[dim]

            needs_rectify = set(both_keep[rectify_flags])
            needs_rectify.update(rrec)
            needs_rectify.update(lrec)

            nrec = len(needs_rectify)
            # print('nrec = {!r}'.format(nrec))
            if nrec > recsize:
                both_keep = _rectify(tlbr, both_keep, needs_rectify)
                needs_rectify = set()
        return both_keep, needs_rectify

    if not ub.iterable(diameter):
        diameter_wh = [diameter, diameter]
    else:
        diameter_wh = diameter[::-1]

    depth = 0
    dim = 0
    both_keep, needs_rectify = _recurse(tlbr,
                                        scores,
                                        dim=dim,
                                        depth=depth,
                                        diameter_wh=diameter_wh)
    keep = _rectify(tlbr, both_keep, needs_rectify)
    return keep
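The split criterion is simply the median box coordinate along one axis, alternating axes at each level like a kd-tree; only survivors within the diameter of the split line need the rectification pass. A tiny illustration of one split step on raw tlbr data:

    import numpy as np

    tlbr = np.array([[0, 0, 5, 5], [90, 90, 95, 95], [40, 40, 60, 60]],
                    dtype=np.float32)
    dim = 0                                   # split on x first, then y, ...
    middle = np.median(tlbr.T[dim])           # median tlx value
    left_idxs = np.where(tlbr.T[dim] < middle)[0]
    right_idxs = np.where(tlbr.T[dim] >= middle)[0]
    # Boxes whose coordinate lies within diameter_wh[dim] of `middle` may
    # straddle the split, so _rectify rechecks them after both halves solve.
    print(middle, left_idxs, right_idxs)  # 40.0 [0] [1 2]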
Example #9
def _padded_slice_embed(in_slice, data_dims, pad=None):
    """
    Embeds a "padded-slice" inside known data dimension.

    Returns the valid data portion of the slice with extra padding for regions
    outside of the available dimension.

    Given slices for each dimension, image dimensions, and a padding, get the
    corresponding slice from the image and any extra padding needed to achieve
    the requested window size.

    Args:
        in_slice (Tuple[slice]):
            a tuple of slices to apply, one per data dimension.
        data_dims (Tuple[int]):
            n-dimension data sizes (e.g. 2d height, width)
        pad (List[int | Tuple[int, int]]):
            extra pad applied to the (left and right) / (both) sides of each
            slice dim

    Returns:
        Tuple:
            data_slice - Tuple[slice] a slice that can be applied to an array
                with shape `data_dims`. This slice will not correspond to
                the full window size if the requested slice is out of bounds.
            extra_padding - extra padding needed after slicing to achieve
                the requested window size.

    Example:
        >>> # Case where slice is inside the data dims on left edge
        >>> from kwimage.im_core import *  # NOQA
        >>> in_slice = (slice(0, 10), slice(0, 10))
        >>> data_dims  = [300, 300]
        >>> pad        = [10, 5]
        >>> a, b = _padded_slice_embed(in_slice, data_dims, pad)
        >>> print('data_slice = {!r}'.format(a))
        >>> print('extra_padding = {!r}'.format(b))
        data_slice = (slice(0, 20, None), slice(0, 15, None))
        extra_padding = [(10, 0), (5, 0)]

    Example:
        >>> # Case where slice is bigger than the image
        >>> in_slice = (slice(-10, 400), slice(-10, 400))
        >>> data_dims  = [300, 300]
        >>> pad        = [10, 5]
        >>> a, b = _padded_slice_embed(in_slice, data_dims, pad)
        >>> print('data_slice = {!r}'.format(a))
        >>> print('extra_padding = {!r}'.format(b))
        data_slice = (slice(0, 300, None), slice(0, 300, None))
        extra_padding = [(20, 110), (15, 105)]

    Example:
        >>> # Case where slice is inside the image
        >>> in_slice = (slice(10, 40), slice(10, 40))
        >>> data_dims  = [300, 300]
        >>> pad        = None
        >>> a, b = _padded_slice_embed(in_slice, data_dims, pad)
        >>> print('data_slice = {!r}'.format(a))
        >>> print('extra_padding = {!r}'.format(b))
        data_slice = (slice(10, 40, None), slice(10, 40, None))
        extra_padding = [(0, 0), (0, 0)]
    """
    low_dims = [sl.start for sl in in_slice]
    high_dims = [sl.stop for sl in in_slice]

    # Determine the real part of the image that can be sliced out
    data_slice_st = []
    extra_padding = []
    if pad is None:
        pad = 0
    if isinstance(pad, int):
        pad = [pad] * len(data_dims)
    # Normalize to left/right pad value for each dim
    pad_slice = [p if ub.iterable(p) else [p, p] for p in pad]

    # Determine the real part of the image that can be sliced out
    for D_img, d_low, d_high, d_pad in zip(data_dims, low_dims, high_dims,
                                           pad_slice):
        if d_low > d_high:
            raise ValueError('d_low > d_high: {} > {}'.format(d_low, d_high))
        # Determine where the bounds would be if the image size was inf
        raw_low = d_low - d_pad[0]
        raw_high = d_high + d_pad[1]
        # Clip the slice positions to the real part of the image
        sl_low = min(D_img, max(0, raw_low))
        sl_high = min(D_img, max(0, raw_high))
        data_slice_st.append((sl_low, sl_high))

        # Add extra padding when the window extends past the real part
        low_diff = sl_low - raw_low
        high_diff = raw_high - sl_high

        # Handle the case where both raw coordinates are out of bounds
        extra_low = max(0, low_diff + min(0, high_diff))
        extra_high = max(0, high_diff + min(0, low_diff))
        extra = (extra_low, extra_high)
        extra_padding.append(extra)

    data_slice = tuple(slice(s, t) for s, t in data_slice_st)
    return data_slice, extra_padding
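The two return values compose: slice the real data first, then pad the result back out to the requested window size. A minimal sketch of that composition, assuming `_padded_slice_embed` as defined above (the zero-fill pad mode is an illustrative choice):

    import numpy as np
    data = np.arange(300 * 300).reshape(300, 300)
    in_slice = (slice(-10, 400), slice(-10, 400))
    data_slice, extra_padding = _padded_slice_embed(in_slice, data.shape, [10, 5])
    part = data[data_slice]
    # pad the in-bounds part back out to the full requested window
    window = np.pad(part, extra_padding, mode='constant')
    assert window.shape == (430, 420)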
Exemple #10
0
def _assign_confusion_vectors(true_dets,
                              pred_dets,
                              bg_weight=1.0,
                              iou_thresh=0.5,
                              bg_cidx=-1,
                              bias=0.0,
                              classes=None,
                              compat='all',
                              prioritize='iou',
                              ignore_classes='ignore',
                              max_dets=None):
    """
    Create confusion vectors for detections by assigning predictions to
    ground truth boxes.

    Given predictions and truth for an image, return (y_pred, y_true,
    y_score), which is suitable for sklearn classification metrics.

    Args:
        true_dets (Detections):
            groundtruth with boxes, classes, and weights

        pred_dets (Detections):
            predictions with boxes, classes, and scores

        iou_thresh (float, default=0.5):
            bounding box overlap iou threshold required for assignment

        bias (float, default=0.0):
            bias used when computing bounding box overlap; should be
            either 1.0 or 0.0

        bg_weight (float, default=1.0):
            weight applied to background rows, i.e. predictions that were
            not assigned to any true box.

        compat (str, default='all'):
            can be ('ancestors' | 'mutex' | 'all').  determines which pred
            boxes are allowed to match which true boxes. If 'mutex', then
            pred boxes can only match true boxes of the same class. If
            'ancestors', then pred boxes can match true boxes that match or
            have a coarser label. If 'all', then any pred can match any
            true, regardless of its category label.

        prioritize (str, default='iou'):
            can be ('iou' | 'class' | 'correct'). Determines which true box
            to assign to if multiple true boxes overlap a predicted box. If
            prioritize is iou, then the true box with maximum iou (above
            iou_thresh) will be chosen. If prioritize is class, then it will
            prefer matching a compatible class over a higher iou. If
            prioritize is correct, then ancestors of the true class are
            preferred over descendants of the true class, which are in turn
            preferred over unrelated classes.

        bg_cidx (int, default=-1):
            The index of the background class.  The index used in the truth
            column when a predicted bounding box does not match any true
            bounding box.

        classes (List[str] | kwcoco.CategoryTree):
            mapping from class indices to class names. Can also contain class
            hierarchy information.

        ignore_classes (str | List[str]):
            class name(s) indicating ignore regions

        max_dets (int): maximum number of detections to consider

    TODO:
        - [ ] This is a bottleneck function. An implementation in C / C++ /
              Cython would likely improve the overall system.

        - [ ] Implement crowd truth. Allow multiple predictions to match any
              truth object marked as "iscrowd".

    Returns:
        dict: with relevant confusion vectors. The keys of this dict can be
            interpreted as columns of a data frame. The `txs` / `pxs` columns
            represent the indexes of the true / predicted annotations that were
            assigned as matching. Additionally each row also contains the true
            and predicted class index, the predicted score, the true weight and
            the iou of the true and predicted boxes. A `txs` value of -1 means
            that the predicted box was not assigned to a true annotation and a
            `pxs` value of -1 means that the true annotation was not assigned
            to any predicted annotation.

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> import pandas as pd
        >>> import kwimage
        >>> # Given a raw numpy representation construct Detection wrappers
        >>> true_dets = kwimage.Detections(
        >>>     boxes=kwimage.Boxes(np.array([
        >>>         [ 0,  0, 10, 10], [10,  0, 20, 10],
        >>>         [10,  0, 20, 10], [20,  0, 30, 10]]), 'tlbr'),
        >>>     weights=np.array([1, 0, .9, 1]),
        >>>     class_idxs=np.array([0, 0, 1, 2]))
        >>> pred_dets = kwimage.Detections(
        >>>     boxes=kwimage.Boxes(np.array([
        >>>         [6, 2, 20, 10], [3,  2, 9, 7],
        >>>         [3,  9, 9, 7],  [3,  2, 9, 7],
        >>>         [2,  6, 7, 7],  [20,  0, 30, 10]]), 'tlbr'),
        >>>     scores=np.array([.5, .5, .5, .5, .5, .5]),
        >>>     class_idxs=np.array([0, 0, 1, 2, 0, 1]))
        >>> bg_weight = 1.0
        >>> compat = 'all'
        >>> iou_thresh = 0.5
        >>> bias = 0.0
        >>> import kwcoco
        >>> classes = kwcoco.CategoryTree.from_mutex(list(range(3)))
        >>> bg_cidx = -1
        >>> y = _assign_confusion_vectors(true_dets, pred_dets, bias=bias,
        >>>                               bg_weight=bg_weight, iou_thresh=iou_thresh,
        >>>                               compat=compat)
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
           pred  true  score  weight     iou  txs  pxs
        0     1     2 0.5000  1.0000  1.0000    3    5
        1     0    -1 0.5000  1.0000 -1.0000   -1    4
        2     2    -1 0.5000  1.0000 -1.0000   -1    3
        3     1    -1 0.5000  1.0000 -1.0000   -1    2
        4     0    -1 0.5000  1.0000 -1.0000   -1    1
        5     0     0 0.5000  0.0000  0.6061    1    0
        6    -1     0 0.0000  1.0000 -1.0000    0   -1
        7    -1     1 0.0000  0.9000 -1.0000    2   -1

    Ignore:
        from xinspect.dynamic_kwargs import get_func_kwargs
        globals().update(get_func_kwargs(_assign_confusion_vectors))

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> import pandas as pd
        >>> from kwcoco.metrics import DetectionMetrics
        >>> dmet = DetectionMetrics.demo(nimgs=1, nclasses=8,
        >>>                              nboxes=(0, 20), n_fp=20,
        >>>                              box_noise=.2, cls_noise=.3)
        >>> classes = dmet.classes
        >>> gid = 0
        >>> true_dets = dmet.true_detections(gid)
        >>> pred_dets = dmet.pred_detections(gid)
        >>> y = _assign_confusion_vectors(true_dets, pred_dets,
        >>>                               classes=dmet.classes,
        >>>                               compat='all', prioritize='class')
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
        >>> y = _assign_confusion_vectors(true_dets, pred_dets,
        >>>                               classes=dmet.classes,
        >>>                               compat='ancestors', iou_thresh=.5)
        >>> y = pd.DataFrame(y)
        >>> print(y)  # xdoc: +IGNORE_WANT
    """
    import kwarray
    valid_compat_keys = {'ancestors', 'mutex', 'all'}
    if compat not in valid_compat_keys:
        raise KeyError(compat)
    if classes is None and compat == 'ancestors':
        compat = 'mutex'

    if compat == 'mutex':
        prioritize = 'iou'

    # Group true boxes by class
    # Keep track which true boxes are unused / not assigned
    unique_tcxs, tgroupxs = kwarray.group_indices(true_dets.class_idxs)
    cx_to_txs = dict(zip(unique_tcxs, tgroupxs))

    unique_pcxs = np.array(sorted(set(pred_dets.class_idxs)))

    if classes is None:
        import kwcoco
        # Build mutually exclusive category tree
        all_cxs = sorted(
            set(map(int, unique_pcxs)) | set(map(int, unique_tcxs)))
        all_cxs = list(range(max(all_cxs) + 1))
        classes = kwcoco.CategoryTree.from_mutex(all_cxs)

    cx_to_ancestors = classes.idx_to_ancestor_idxs()

    if prioritize == 'iou':
        pdist_priority = None  # TODO: cleanup
    else:
        pdist_priority = _fast_pdist_priority(classes, prioritize)

    if compat == 'mutex':
        # assume classes are mutually exclusive if hierarchy is not given
        cx_to_matchable_cxs = {cx: [cx] for cx in unique_pcxs}
    elif compat == 'ancestors':
        cx_to_matchable_cxs = {
            cx: sorted([cx] + sorted(
                ub.take(classes.node_to_idx,
                        nx.ancestors(classes.graph, classes.idx_to_node[cx]))))
            for cx in unique_pcxs
        }
    elif compat == 'all':
        cx_to_matchable_cxs = {cx: unique_tcxs for cx in unique_pcxs}
    else:
        raise KeyError(compat)

    if compat == 'all':
        # In this case simply run the full pairwise iou
        common_true_idxs = np.arange(len(true_dets))
        cx_to_matchable_txs = {cx: common_true_idxs for cx in unique_pcxs}
        common_ious = pred_dets.boxes.ious(true_dets.boxes, bias=bias)
        # common_ious = pred_dets.boxes.ious(true_dets.boxes, impl='c', bias=bias)
        iou_lookup = dict(enumerate(common_ious))
    else:
        # For each pred-category find matchable true-indices
        cx_to_matchable_txs = {}
        for cx, compat_cx in cx_to_matchable_cxs.items():
            matchable_cxs = cx_to_matchable_cxs[cx]
            compat_txs = ub.dict_take(cx_to_txs, matchable_cxs, default=[])
            compat_txs = np.array(sorted(ub.flatten(compat_txs)), dtype=int)
            cx_to_matchable_txs[cx] = compat_txs

        # Batch up the IOU pre-computation between compatible truths / preds
        iou_lookup = {}
        unique_pred_cxs, pgroupxs = kwarray.group_indices(pred_dets.class_idxs)
        for cx, pred_idxs in zip(unique_pred_cxs, pgroupxs):
            true_idxs = cx_to_matchable_txs[cx]
            ious = pred_dets.boxes[pred_idxs].ious(true_dets.boxes[true_idxs],
                                                   bias=bias)
            _px_to_iou = dict(zip(pred_idxs, ious))
            iou_lookup.update(_px_to_iou)

    iou_thresh_list = ([iou_thresh]
                       if not ub.iterable(iou_thresh) else iou_thresh)

    iou_thresh_to_y = {}
    for iou_thresh_ in iou_thresh_list:
        isvalid_lookup = {
            px: ious > iou_thresh_
            for px, ious in iou_lookup.items()
        }

        y = _critical_loop(true_dets,
                           pred_dets,
                           iou_lookup,
                           isvalid_lookup,
                           cx_to_matchable_txs,
                           bg_weight,
                           prioritize,
                           iou_thresh_,
                           pdist_priority,
                           cx_to_ancestors,
                           bg_cidx,
                           ignore_classes=ignore_classes,
                           max_dets=max_dets)
        iou_thresh_to_y[iou_thresh_] = y

    if ub.iterable(iou_thresh):
        return iou_thresh_to_y
    else:
        return y
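Note that the return type depends on `iou_thresh`: a scalar threshold yields a single set of confusion vectors, while a list yields a dict keyed by threshold. A small sketch of the multi-threshold calling pattern, reusing `true_dets` / `pred_dets` from the doctests above:

    ys = _assign_confusion_vectors(true_dets, pred_dets,
                                   iou_thresh=[0.5, 0.75])
    y_at_50 = ys[0.5]   # confusion vectors computed at IoU 0.5
    y_at_75 = ys[0.75]  # confusion vectors computed at IoU 0.75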
Exemple #11
0
def online_affine_perterb_np(np_images,
                             rng,
                             interp='cubic',
                             border_mode='reflect',
                             **kw):
    """
    Args:
        np_images (list): list of images that all receive the same transform

    Example:
        >>> from clab.augment.augment_numpy_online import *
        >>> import ubelt as ub
        >>> import numpy as np
        >>> import plottool as pt
        >>> rng = np.random
        >>> fpath = ub.grabdata('https://i.imgur.com/oHGsmvF.png', fname='carl.png')
        >>> img = imutil.imread(fpath)
        >>> np_images = [img]
        >>> kw = {}
        >>> imaug, = online_affine_perterb_np([img], rng)
        >>> pt.imshow(np.array(imaug))

    Ignore:
        Aff = affine_around_mat2x3(0, 0)
        matrix = np.array(Aff + [[0, 0, 1]])
        skaff = skimage.transform.AffineTransform(matrix=matrix)

        # CONCLUSION: this can handle n-channel images
        img2 = np.random.rand(24, 24, 5)
        imaug2 = skimage.transform.warp(
            img2, skaff, output_shape=img2.shape, order=0, mode='reflect',
            clip=True, preserve_range=True)

    """
    augkw = PERTERB_AUG_KW.copy()
    augkw.update(kw)
    affine_args = random_affine_args(rng=rng, **augkw)

    if not ub.iterable(interp):
        interps = [interp] * len(np_images)
    else:
        interps = interp
    assert len(interps) == len(np_images)

    for img, interp_ in zip(np_images, interps):
        h1, w1 = img.shape[0:2]
        x, y = w1 / 2, h1 / 2

        Aff = affine_around_mat2x3(x, y, *affine_args)
        matrix = np.array(Aff + [[0, 0, 1]])
        skaff = skimage.transform.AffineTransform(matrix=matrix)

        order = SKIMAGE_INTERP_LOOKUP[interp_]

        imaug = skimage.transform.warp(
            img,
            skaff,
            output_shape=img.shape,
            order=order,
            mode=border_mode,
            # cval=0.0,
            clip=True,
            preserve_range=True)
        imaug = imaug.astype(img.dtype)

        # imaug = cv2.warpAffine(
        #     img, Aff,
        #     dsize=(w1, h1),
        #     flags=cv2.INTER_LANCZOS4,
        #     borderMode=cv2.BORDER_REFLECT
        # )
        yield imaug
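Because this is a generator, the same random affine parameters are applied lazily to each input image as it is consumed. A hedged sketch of augmenting an image together with an aligned mask using per-image interpolation (the mask array is hypothetical, and the sketch assumes 'nearest' is a key of SKIMAGE_INTERP_LOOKUP):

    rng = np.random
    mask = np.zeros(img.shape[0:2], dtype=np.uint8)  # hypothetical aligned mask
    img_aug, mask_aug = online_affine_perterb_np(
        [img, mask], rng, interp=['cubic', 'nearest'])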
Exemple #12
0
def draw_boxes(boxes,
               alpha=None,
               color='blue',
               labels=None,
               centers=False,
               fill=False,
               ax=None,
               lw=2):
    """
    Args:
        boxes (kwimage.Boxes): boxes to draw
        labels (List[str]): a text label for each box
        alpha (List[float]): alpha (opacity) for each box
        centers (bool): if True, draw the center of each box
        fill (bool): if True, fill each box with its edge color
        ax (matplotlib.axes.Axes): axes to draw on; defaults to the current axes
        lw (float): linewidth

    Example:
        >>> import kwimage
        >>> bboxes = kwimage.Boxes([[.1, .1, .6, .3], [.3, .5, .5, .6]], 'xywh')
        >>> draw_boxes(bboxes)
        >>> #kwplot.autompl()
    """
    import kwplot
    import matplotlib as mpl
    from matplotlib import pyplot as plt
    if ax is None:
        ax = plt.gca()

    xywh = boxes.to_xywh().data

    transparent = kwplot.Color((0, 0, 0, 0)).as01('rgba')

    # More grouped patches == more efficient runtime
    if alpha is None:
        alpha = [1.0] * len(xywh)
    elif not ub.iterable(alpha):
        alpha = [alpha] * len(xywh)

    edgecolors = [kwplot.Color(color, alpha=a).as01('rgba') for a in alpha]
    color_groups = ub.group_items(range(len(edgecolors)), edgecolors)
    for edgecolor, idxs in color_groups.items():
        if fill:
            fc = edgecolor
        else:
            fc = transparent
        rectkw = dict(ec=edgecolor, fc=fc, lw=lw, linestyle='solid')
        patches = [
            mpl.patches.Rectangle((x, y), w, h, **rectkw)
            for x, y, w, h in xywh[idxs]
        ]
        col = mpl.collections.PatchCollection(patches, match_original=True)
        ax.add_collection(col)

    if centers not in [None, False]:
        default_centerkw = {
            # 'radius': 1,
            'fill': True
        }
        centerkw = default_centerkw.copy()
        if isinstance(centers, dict):
            centerkw.update(centers)
        xy_centers = boxes.xy_center
        for fcolor, idxs in color_groups.items():
            # TODO: radius based on size of bbox
            # if 'radius' not in centerkw:
            #     boxes.area[idxs]

            patches = [
                mpl.patches.Circle((x, y), ec=None, fc=fcolor, **centerkw)
                for x, y in xy_centers[idxs]
            ]
            col = mpl.collections.PatchCollection(patches, match_original=True)
            ax.add_collection(col)

    if labels:
        texts = []
        default_textkw = {
            'horizontalalignment': 'left',
            'verticalalignment': 'top',
            'backgroundcolor': (0, 0, 0, .8),
            'color': 'white',
            'fontproperties': mpl.font_manager.FontProperties(
                size=6, family='monospace'),
        }
        tkw = default_textkw.copy()
        for (x1, y1, w, h), label in zip(xywh, labels):
            texts.append((x1, y1, label, tkw))
        for (x1, y1, catname, tkw) in texts:
            ax.text(x1, y1, catname, **tkw)
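Adding PatchCollections does not necessarily autoscale the axes, so boxes in normalized coordinates may be drawn off-screen. A small sketch that sets image-style limits explicitly (the limits and label strings are illustrative):

    import kwimage
    from matplotlib import pyplot as plt
    bboxes = kwimage.Boxes([[.1, .1, .6, .3], [.3, .5, .5, .6]], 'xywh')
    ax = plt.gca()
    draw_boxes(bboxes, labels=['box-0', 'box-1'], centers=True, ax=ax)
    ax.set_xlim(0, 1)
    ax.set_ylim(1, 0)  # y-down, matching image coordinates
    plt.show()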
Exemple #13
0
    def random_params(cls, rng=None, **kw):
        """
        Args:
            rng : random number generator
            **kw: can contain coercible random distributions for
                scale, offset, about, theta, and shear.

        Returns:
            Dict: affine parameters suitable to be passed to Affine.affine

        TODO:
            - [ ] improve kwargs parameterization
        """
        from kwarray import distributions
        import numbers
        TN = distributions.TruncNormal
        rng = kwarray.ensure_rng(rng)

        def _coerce_distri(arg):
            if isinstance(arg, numbers.Number):
                dist = distributions.Constant(arg, rng=rng)
            else:
                raise NotImplementedError
            return dist

        if 'scale' in kw:
            if ub.iterable(kw['scale']):
                raise NotImplementedError
            else:
                xscale_dist = _coerce_distri(kw['scale'])
                yscale_dist = xscale_dist
        else:
            scale_kw = dict(mean=1, std=1, low=1, high=2)
            xscale_dist = TN(**scale_kw, rng=rng)
            yscale_dist = TN(**scale_kw, rng=rng)

        if 'offset' in kw:
            if ub.iterable(kw['offset']):
                raise NotImplementedError
            else:
                xoffset_dist = _coerce_distri(kw['offset'])
                yoffset_dist = xoffset_dist
        else:
            offset_kw = dict(mean=0, std=1, low=-1, high=1)
            xoffset_dist = TN(**offset_kw, rng=rng)
            yoffset_dist = TN(**offset_kw, rng=rng)

        if 'about' in kw:
            if ub.iterable(kw['about']):
                raise NotImplementedError
            else:
                xabout_dist = _coerce_distri(kw['about'])
                yabout_dist = xabout_dist
        else:
            xabout_dist = distributions.Constant(0, rng=rng)
            yabout_dist = distributions.Constant(0, rng=rng)

        if 'theta' in kw:
            theta_dist = _coerce_distri(kw['theta'])
        else:
            theta_kw = dict(mean=0, std=1, low=-np.pi / 8, high=np.pi / 8)
            theta_dist = TN(**theta_kw, rng=rng)

        if 'shear' in kw:
            shear_dist = _coerce_distri(kw['shear'])
        else:
            shear_dist = distributions.Constant(0, rng=rng)

        # scale_kw = dict(mean=1, std=1, low=0, high=2)
        # offset_kw = dict(mean=0, std=1, low=-1, high=1)
        # theta_kw = dict(mean=0, std=1, low=-6.28, high=6.28)

        # TODO: distributions.Distribution.coerce()
        # offset_dist = distributions.Constant(0)
        # theta_dist = distributions.Constant(0)

        # TODO: better parameterization
        params = dict(
            scale=(xscale_dist.sample(), yscale_dist.sample()),
            offset=(xoffset_dist.sample(), yoffset_dist.sample()),
            theta=theta_dist.sample(),
            shear=shear_dist.sample(),
            about=(xabout_dist.sample(), yabout_dist.sample()),
        )
        return params
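A hedged usage sketch, assuming this classmethod lives on kwimage.Affine and that the Affine.affine factory accepts the returned keys (as the Returns section states):

    import kwimage
    params = kwimage.Affine.random_params(rng=0, theta=0.1)
    transform = kwimage.Affine.affine(**params)
    print(transform.matrix)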
Exemple #14
0
    def draw(self, color='blue', ax=None, alpha=None, radius=1, **kwargs):
        """
        TODO: can use kwplot.draw_points

        Example:
            >>> # xdoc: +REQUIRES(module:kwplot)
            >>> from kwimage.structs.points import *  # NOQA
            >>> pts = Points.random(10)
            >>> # xdoc: +REQUIRES(--show)
            >>> pts.draw(radius=0.01)

            >>> from kwimage.structs.points import *  # NOQA
            >>> self = Points.random(10, classes=['a', 'b', 'c'])
            >>> self.draw(radius=0.01, color='classes')
        """
        import kwimage
        import matplotlib as mpl
        from matplotlib import pyplot as plt
        if ax is None:
            ax = plt.gca()
        xy = self.data['xy'].data.reshape(-1, 2)

        # More grouped patches == more efficient runtime
        if alpha is None:
            alpha = [1.0] * len(xy)
        elif not ub.iterable(alpha):
            alpha = [alpha] * len(xy)

        if color == 'distinct':
            colors = kwimage.Color.distinct(len(alpha))
        elif color == 'classes':
            # TODO: read colors from categories if they exist
            try:
                class_idxs = self.data['class_idxs']
                cls_colors = kwimage.Color.distinct(len(self.meta['classes']))
            except KeyError:
                raise Exception('cannot draw class colors without class_idxs and classes')
            _keys, _vals = kwarray.group_indices(class_idxs)
            colors = list(ub.take(cls_colors, class_idxs))
        else:
            colors = [color] * len(alpha)

        ptcolors = [kwimage.Color(c, alpha=a).as01('rgba')
                    for c, a in zip(colors, alpha)]
        color_groups = ub.group_items(range(len(ptcolors)), ptcolors)

        circlekw = {
            'radius': radius,
            'fill': True,
            'ec': None,
        }
        if 'fc' in kwargs:
            import warnings
            warnings.warn(
                'Warning: specifying fc to Points.draw overrides '
                'the color argument. Use color instead')
        circlekw.update(kwargs)
        fc = circlekw.pop('fc', None)  # hack

        collections = []
        for pcolor, idxs in color_groups.items():

            # hack for fc
            if fc is not None:
                pcolor = fc

            patches = [
                mpl.patches.Circle((x, y), fc=pcolor, **circlekw)
                for x, y in xy[idxs]
            ]
            col = mpl.collections.PatchCollection(patches, match_original=True)
            collections.append(col)
            ax.add_collection(col)
        return collections
Exemple #15
0
    def draw(self,
             color='blue',
             ax=None,
             alpha=None,
             coord_axes=[1, 0],
             radius=1,
             setlim=False):
        """
        Note:
            unlike other methods, the defaults assume x/y internal data

        Args:
            setlim (bool): if True ensures the limits of the axes contains the
                polygon

            coord_axes (Tuple): specify which image axes each coordinate dim
                corresponds to. For 2D images:
                    if you are storing r/c data, set to [0, 1];
                    if you are storing x/y data, set to [1, 0].

        Returns:
            List[mpl.collections.PatchCollection]: drawn matplotlib objects

        Example:
            >>> # xdoc: +REQUIRES(module:kwplot)
            >>> from kwimage.structs.coords import *  # NOQA
            >>> self = Coords.random(10)
            >>> # xdoc: +REQUIRES(--show)
            >>> self.draw(radius=3.0, setlim=True)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> self.draw(radius=3.0)
        """
        import matplotlib as mpl
        import kwimage
        from matplotlib import pyplot as plt
        if ax is None:
            ax = plt.gca()
        data = self.data

        if self.dim != 2:
            raise NotImplementedError('need 2d for mpl')

        # More grouped patches == more efficient runtime
        if alpha is None:
            alpha = [1.0] * len(data)
        elif not ub.iterable(alpha):
            alpha = [alpha] * len(data)

        ptcolors = [kwimage.Color(color, alpha=a).as01('rgba') for a in alpha]
        color_groups = ub.group_items(range(len(ptcolors)), ptcolors)

        default_centerkw = {'radius': radius, 'fill': True}
        centerkw = default_centerkw.copy()
        collections = []
        for pcolor, idxs in color_groups.items():
            yx_list = [row[coord_axes] for row in data[idxs]]
            patches = [
                mpl.patches.Circle((x, y), ec=None, fc=pcolor, **centerkw)
                for y, x in yx_list
            ]
            col = mpl.collections.PatchCollection(patches, match_original=True)
            collections.append(col)
            ax.add_collection(col)

        if setlim:
            x1, y1 = self.data.min(axis=0)
            x2, y2 = self.data.max(axis=0)

            if setlim == 'grow':
                # only allow growth
                x1_, x2_ = ax.get_xlim()
                y1_, y2_ = ax.get_ylim()
                x1 = min(x1_, x1)
                x2 = max(x2_, x2)
                y1 = min(y1_, y1)
                y2 = max(y2_, y2)

            ax.set_xlim(x1, x2)
            ax.set_ylim(y1, y2)
        return collections
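A hedged sketch of the coord_axes switch, assuming the same random data is interpreted as row/column rather than the default x/y:

    self = Coords.random(10)
    self.draw(coord_axes=[0, 1], radius=3.0, setlim=True)  # treat data as r/c
    self.draw(coord_axes=[1, 0], radius=3.0, setlim=True)  # treat data as x/y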
Exemple #16
0
    def __getitem__(self, index):
        """
        References:
            https://gis.stackexchange.com/questions/162095/gdal-driver-create-typeerror

        Ignore:
            >>> from ndsampler.utils.util_gdal import *  # NOQA
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
            >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None]
            >>> img_part = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(img_part)
        """
        ds = self._ds
        width = ds.RasterXSize
        height = ds.RasterYSize
        C = ds.RasterCount

        if not ub.iterable(index):
            index = [index]

        index = list(index)
        if len(index) < 3:
            n = (3 - len(index))
            index = index + [None] * n

        ypart = _rectify_slice_dim(index[0], height)
        xpart = _rectify_slice_dim(index[1], width)
        channel_part = _rectify_slice_dim(index[2], C)
        trailing_part = [channel_part]

        if len(trailing_part) == 1:
            channel_part = trailing_part[0]
            rb_indices = range(*channel_part.indices(C))
        else:
            rb_indices = range(C)
            assert len(trailing_part) <= 1

        ystart, ystop = map(int, [ypart.start, ypart.stop])
        xstart, xstop = map(int, [xpart.start, xpart.stop])

        ysize = ystop - ystart
        xsize = xstop - xstart

        gdalkw = dict(xoff=xstart,
                      yoff=ystart,
                      win_xsize=xsize,
                      win_ysize=ysize)

        PREALLOC = 1
        if PREALLOC:
            # preallocate like kwimage.im_io._imread_gdal
            from kwimage.im_io import _gdal_to_numpy_dtype
            shape = (ysize, xsize, len(rb_indices))
            bands = [ds.GetRasterBand(1 + rb_idx) for rb_idx in rb_indices]
            gdal_dtype = bands[0].DataType
            dtype = _gdal_to_numpy_dtype(gdal_dtype)
            img_part = np.empty(shape, dtype=dtype)
            for out_idx, rb in enumerate(bands):
                img_part[:, :, out_idx] = rb.ReadAsArray(**gdalkw)
        else:
            channels = []
            for rb_idx in rb_indices:
                rb = ds.GetRasterBand(1 + rb_idx)
                channel = rb.ReadAsArray(**gdalkw)
                channels.append(channel)
            img_part = np.dstack(channels)
        return img_part
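Indexing therefore reads only the requested window from disk via ReadAsArray rather than materializing the full raster. A small sketch of the supported slicing forms, continuing from the demo object in the Ignore block (the slice bounds are illustrative):

    self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
    window = self[100:200, 300:400]          # rows, cols, all bands
    one_band = self[100:200, 300:400, 0:1]   # restrict the band slice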
Exemple #17
0
    def load(self, data=None, cmdline=False, mode=None, default=None):
        """
        Updates the default configuration from a given data source.

        Any option can be overwritten via the command line if `cmdline` is
        truthy.

        Args:
            data (PathLike | dict):
                Either a path to a yaml / json file or a config dict

            cmdline (bool | List[str] | str, default=False):
                If False, then no command line information is used.
                If True, then sys.argv is parsed and used.
                If a list of strings, that list is used instead of sys.argv.
                If a string, then that is parsed using shlex and used instead
                    of sys.argv.

            mode (str | None):
                Either 'yaml' or 'json'. If unspecified, the mode is inferred
                from the file extension, defaulting to yaml.

            default (dict | None):
                If given, updates the defaults before the data is loaded.

        Example:
            >>> # Test load works correctly in cmdline True and False mode
            >>> import scriptconfig as scfg
            >>> class MyConfig(scfg.Config):
            >>>     default = {
            >>>         'src': scfg.Value(None, help=('some help msg')),
            >>>     }
            >>> data = {'src': 'hi'}
            >>> self = MyConfig(data=data, cmdline=False)
            >>> assert self['src'] == 'hi'
            >>> self = MyConfig(default=data, cmdline=True)
            >>> assert self['src'] == 'hi'
            >>> # In 0.5.8 and previous src fails to populate!
            >>> # This is because cmdline=True overwrites data with defaults
            >>> self = MyConfig(data=data, cmdline=True)
            >>> assert self['src'] == 'hi'

        """
        if default:
            self.update_defaults(default)

        # Maybe this shouldn't be a deep copy?
        _default = copy.deepcopy(self._default)

        if mode is None:
            if isinstance(data, six.string_types):
                if data.lower().endswith('.json'):
                    mode = 'json'
        if mode is None:
            # Default to yaml
            mode = 'yaml'

        if data is None:
            user_config = {}
        elif isinstance(data, six.string_types) or hasattr(data, 'readable'):
            with FileLike(data, 'r') as file:
                user_config = yaml.load(file, Loader=yaml.SafeLoader)
            user_config.pop('__heredoc__', None)  # ignore special heredoc key
        elif isinstance(data, dict):
            user_config = data
        elif scfg_isinstance(data, Config):
            user_config = data.asdict()
        else:
            raise TypeError(
                'Expected path or dict, but got {}'.format(type(data)))

        # check for unknown values
        unknown_keys = set(user_config) - set(_default)
        if unknown_keys:
            raise KeyError('Unknown data options {}'.format(unknown_keys))

        self._data = _default.copy()
        self.update(user_config)

        if isinstance(cmdline, six.string_types):
            # allow specification using the actual command line arg string
            import shlex
            import os
            cmdline = shlex.split(os.path.expandvars(cmdline))

        if cmdline or ub.iterable(cmdline):
            # TODO: if user_config is specified, then we should probably not
            # override any values in user_config with the defaults? The CLI
            # should override them IF they exist on in sys.argv, but not if
            # they don't?
            argv = cmdline if ub.iterable(cmdline) else None
            self._read_argv(argv=argv)

        self.normalize()
        return self
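A hedged sketch of the file-path branch (the yaml path and its contents are hypothetical; per the Args section, a string is treated as a path to a yaml / json file):

    # contents of my_config.yaml:
    #   src: hi
    self = MyConfig()
    self.load('my_config.yaml', cmdline=False)
    assert self['src'] == 'hi'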
Exemple #18
0
def draw_points(xy,
                color='blue',
                class_idxs=None,
                classes=None,
                ax=None,
                alpha=None,
                radius=1,
                **kwargs):
    """

    Args:
        xy (ndarray): an Nx2 array of point coordinates.

    Example:
        >>> from kwplot.mpl_draw import *  # NOQA
        >>> import kwimage
        >>> xy = kwimage.Points.random(10).xy
        >>> draw_points(xy, radius=0.01)
        >>> draw_points(xy, class_idxs=np.random.randint(0, 3, 10),
        >>>         radius=0.01, classes=['a', 'b', 'c'], color='classes')

    Ignore:
        >>> import kwplot
        >>> kwplot.autompl()
    """
    import kwimage
    import matplotlib as mpl
    from matplotlib import pyplot as plt
    if ax is None:
        ax = plt.gca()

    xy = xy.reshape(-1, 2)

    # More grouped patches == more efficient runtime
    if alpha is None:
        alpha = [1.0] * len(xy)
    elif not ub.iterable(alpha):
        alpha = [alpha] * len(xy)

    if color == 'distinct':
        colors = kwimage.Color.distinct(len(alpha))
    elif color == 'classes':
        # TODO: read colors from categories if they exist
        if class_idxs is None or classes is None:
            raise Exception(
                'cannot draw class colors without class_idxs and classes')
        try:
            cls_colors = kwimage.Color.distinct(len(classes))
        except KeyError:
            raise Exception(
                'cannot draw class colors without class_idxs and classes')
        import kwarray
        _keys, _vals = kwarray.group_indices(class_idxs)
        colors = list(ub.take(cls_colors, class_idxs))
    else:
        colors = [color] * len(alpha)

    ptcolors = [
        kwimage.Color(c, alpha=a).as01('rgba') for c, a in zip(colors, alpha)
    ]
    color_groups = ub.group_items(range(len(ptcolors)), ptcolors)

    circlekw = {
        'radius': radius,
        'fill': True,
        'ec': None,
    }
    if 'fc' in kwargs:
        import warnings
        warnings.warn('Warning: specifying fc to draw_points overrides '
                      'the color argument. Use color instead')
    circlekw.update(kwargs)
    fc = circlekw.pop('fc', None)  # hack

    collections = []
    for pcolor, idxs in color_groups.items():

        # hack for fc
        if fc is not None:
            pcolor = fc

        patches = [
            mpl.patches.Circle((x, y), fc=pcolor, **circlekw)
            for x, y in xy[idxs]
        ]
        col = mpl.collections.PatchCollection(patches, match_original=True)
        collections.append(col)
        ax.add_collection(col)
    return collections
Exemple #19
0
    def argparse(self, parser=None, special_options=False):
        """
        construct or update an argparse.ArgumentParser CLI parser

        Args:
            parser (None | argparse.ArgumentParser): if specified this
                parser is updated with options from this config.

            special_options (bool, default=False):
                adds special scriptconfig options, namely: --config, --dumps,
                and --dump.

        Returns:
            argparse.ArgumentParser : a new or updated argument parser

        CommandLine:
            xdoctest -m scriptconfig.config Config.argparse:0
            xdoctest -m scriptconfig.config Config.argparse:1

        TODO:
            A good CLI spec for lists might be

            # In the case where ``key`` ends with an ``=``, assume the list
            # is given as a comma separated string with optional square
            # brackets at each end.

            --key=[f]

            # In the case where ``key`` does not end with an equals sign and
            # we know the value is supposed to be a list, then we consume
            # arguments until we hit the next one that starts with '--'
            # (which means that list items cannot start with -- but they can
            # contain commas)

        FIXME:

            * In the case where we have an nargs='+' action, and we specify
              the option with an `=`, and then we give position args after it
              there is no way to modify behavior of the action to just look at
              the data in the string without modifying the ArgumentParser
              itself. The action object has no control over it. For example
              `--foo=bar baz biz` will parse as `[baz, biz]` which is really
              not what we want. We may be able to overload ArgumentParser to
              fix this.

        Example:
            >>> # You can now make instances of this class
            >>> import scriptconfig
            >>> self = scriptconfig.Config.demo()
            >>> parser = self.argparse()
            >>> parser.print_help()
            >>> # xdoctest: +REQUIRES(PY3)
            >>> # Python2 argparse does a hard sys.exit instead of raise
            >>> ns, extra = parser.parse_known_args()

        Example:
            >>> # You can now make instances of this class
            >>> import scriptconfig as scfg
            >>> class MyConfig(scfg.Config):
            >>>     description = 'my CLI description'
            >>>     default = {
            >>>         'path1':  scfg.Value(None, position=1, alias='src'),
            >>>         'path2':  scfg.Value(None, position=2, alias='dst'),
            >>>         'dry':  scfg.Value(False, isflag=True),
            >>>         'approx':  scfg.Value(False, isflag=False, alias=['a1', 'a2']),
            >>>     }
            >>> self = MyConfig()
            >>> special_options = True
            >>> parser = None
            >>> parser = self.argparse(special_options=special_options)
            >>> parser.print_help()
            >>> self._read_argv(argv=['objection', '42', '--path1=overruled!'])
            >>> print('self = {!r}'.format(self))

        Ignore:
            >>> self._read_argv(argv=['hi','--path1=foobar'])
            >>> self._read_argv(argv=['hi', 'hello', '--path1=foobar'])
            >>> self._read_argv(argv=['hi', 'hello', '--path1=foobar', '--help'])
            >>> self._read_argv(argv=['--path1=foobar', '--path1=baz'])
            >>> print('self = {!r}'.format(self))
        """
        import argparse

        if parser is None:
            parserkw = self._parserkw()
            parser = argparse.ArgumentParser(**parserkw)

        # Use custom action used to mark which values were explicitly set on
        # the commandline
        parser._explicitly_given = set()

        parent = self

        class ParseAction(argparse.Action):
            def __init__(self, *args, **kwargs):
                super(ParseAction, self).__init__(*args, **kwargs)
                # With scriptconfig nothing should be required by default;
                # all positional arguments should have keyword arg variants.
                # Setting required=False here will prevent positional args
                # from erroring if they are not specified. I don't think
                # there are other side effects, but we should make sure that
                # is actually the case.
                self.required = False

                if self.type is None:
                    # Is this the right place to put this?
                    def _mytype(value):
                        key = self.dest
                        template = parent.default[key]
                        if not isinstance(template, Value):
                            # smartcast non-valued params from commandline
                            value = smartcast.smartcast(value)
                        else:
                            value = template.cast(value)
                        return value

                    self.type = _mytype

                # print('self.type = {!r}'.format(self.type))

            def __call__(action, parser, namespace, values, option_string=None):
                # print('CALL action = {!r}'.format(action))
                # print('option_string = {!r}'.format(option_string))
                # print('values = {!r}'.format(values))

                if isinstance(values, list) and len(values):
                    # We got a list of lists, which we hack into a flat list
                    if isinstance(values[0], list):
                        import itertools as it
                        values = list(it.chain(*values))

                setattr(namespace, action.dest, values)
                parser._explicitly_given.add(action.dest)

        # IRC: this ensures each key has a real Value class
        _metadata = {
            key: self._data[key]
            for key, value in self._default.items()
            if isinstance(self._data[key], Value)
        }  # :type: Dict[str, Value]
        _positions = {k: v.position for k, v in _metadata.items()
                      if v.position is not None}
        if _positions:
            if ub.find_duplicates(_positions.values()):
                raise Exception('two values have the same position')
            _keyorder = ub.oset(ub.argsort(_positions))
            _keyorder |= (ub.oset(self._default) - _keyorder)
        else:
            _keyorder = list(self._default.keys())

        def _add_arg(parser, name, key, argkw, positional, isflag, isalias):
            _argkw = argkw.copy()

            if isalias:
                _argkw['help'] = 'alias of {}'.format(key)
                _argkw.pop('default', None)
                # flags cannot have flag aliases
                isflag = False

            elif positional:
                parser.add_argument(name, **_argkw)

            if isflag:
                # Can we support both flag and setitem methods of cli
                # parsing?
                if not isinstance(_argkw.get('default', None), bool):
                    raise ValueError('can only use isflag with bools')
                _argkw.pop('type', None)
                _argkw.pop('choices', None)
                _argkw.pop('action', None)
                _argkw.pop('nargs', None)
                _argkw['dest'] = key

                _argkw_true = _argkw.copy()
                _argkw_true['action'] = 'store_true'

                _argkw_false = _argkw.copy()
                _argkw_false['action'] = 'store_false'
                _argkw_false.pop('help', None)

                parser.add_argument('--' + name, **_argkw_true)
                parser.add_argument('--no-' + name, **_argkw_false)
            else:
                parser.add_argument('--' + name, **_argkw)

        mode = 1

        alias_registry = []
        for key, value in self._data.items():
            # key: str
            # value: Any | Value
            argkw = {}
            argkw['help'] = ''
            positional = None
            isflag = False
            if key in _metadata:
                # Use the metadata in the Value class to enhance argparse
                _value = _metadata[key]
                argkw.update(_value.parsekw)
                value = _value.value
                isflag = _value.isflag
                positional = _value.position
            else:
                _value = value if isinstance(value, Value) else None

            if not argkw['help']:
                argkw['help'] = '<undocumented>'

            argkw['default'] = value
            argkw['action'] = ParseAction

            name = key
            _add_arg(parser, name, key, argkw, positional, isflag, isalias=False)

            if _value is not None:
                if _value.alias:
                    alts = _value.alias
                    alts = alts if ub.iterable(alts) else [alts]
                    for alias in alts:
                        tup = (alias, key, argkw)
                        alias_registry.append(tup)
                        if mode == 0:
                            name = alias
                            _add_arg(parser, name, key, argkw, positional, isflag, isalias=True)

        if mode == 1:
            for tup in alias_registry:
                (alias, key, argkw) = tup
                name = alias
                dest = key
                _add_arg(parser, name, dest, argkw, positional, isflag, isalias=True)

        if special_options:
            parser.add_argument('--config', default=None, help=ub.codeblock(
                '''
                special scriptconfig option that accepts the path to an on-disk
                configuration file, and loads that into this {!r} object.
                ''').format(self.__class__.__name__))

            parser.add_argument('--dump', default=None, help=ub.codeblock(
                '''
                If specified, dump this config to disk.
                ''').format(self.__class__.__name__))

            parser.add_argument('--dumps', action='store_true', help=ub.codeblock(
                '''
                If specified, dump this config to stdout
                ''').format(self.__class__.__name__))

        return parser
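A hedged sketch wiring this into a CLI entry point, continuing from the MyConfig doctest above (parse_known_args tolerates unrecognized options, mirroring how _read_argv is used):

    self = MyConfig()
    parser = self.argparse(special_options=True)
    ns, unknown = parser.parse_known_args(['objection', '42', '--dry'])
    print(ns.path1, ns.path2, ns.dry)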
Exemple #20
0
def _dump_measures(tb_data,
                   out_dpath,
                   mode=None,
                   smoothing=0.0,
                   ignore_outliers=True):
    """
    This is its own function in case we need to modify formatting

    CommandLine:
        xdoctest -m netharn.mixins _dump_measures --out_dpath=.

    Example:
        >>> # SCRIPT
        >>> # Reread a dumped pickle file
        >>> from netharn.mixins import *  # NOQA
        >>> from netharn.mixins import _dump_monitor_tensorboard, _dump_measures
        >>> import json
        >>> from os.path import join
        >>> import ubelt as ub
        >>> try:
        >>>     import seaborn as sns
        >>>     sns.set()
        >>> except ImportError:
        >>>     pass
        >>> out_dpath = ub.expandpath('~/work/project/fit/nice/nicename/monitor/tensorboard/')
        >>> out_dpath = ub.argval('--out_dpath', default=out_dpath)
        >>> mode = ['epoch', 'iter']
        >>> fpath = join(out_dpath, 'tb_data.json')
        >>> tb_data = json.load(open(fpath, 'r'))
        >>> import kwplot
        >>> kwplot.autompl()
        >>> _dump_measures(tb_data,  out_dpath, smoothing=0)
    """
    import ubelt as ub
    from os.path import join
    import numpy as np
    import kwplot
    import matplotlib as mpl
    from kwplot.auto_backends import BackendContext

    with BackendContext('agg'):
        # kwplot.autompl()

        # TODO: Is it possible to get htop to show this process with some name that
        # distinguishes it from the dataloader workers?
        # import sys
        # import multiprocessing
        # if multiprocessing.current_process().name != 'MainProcess':
        #     if sys.platform.startswith('linux'):
        #         import ctypes
        #         libc = ctypes.cdll.LoadLibrary('libc.so.6')
        #         title = 'Netharn MPL Dump Measures'
        #         libc.prctl(len(title), title, 0, 0, 0)

        # NOTE: This causes warnings when executed as a daemon process
        # try:
        #     import seaborn as sbn
        #     sbn.set()
        # except ImportError:
        #     pass

        valid_modes = ['epoch', 'iter']
        if mode is None:
            mode = valid_modes
        if ub.iterable(mode):
            # Hack: Call with all modes
            for mode_ in mode:
                _dump_measures(tb_data,
                               out_dpath,
                               mode=mode_,
                               smoothing=smoothing,
                               ignore_outliers=ignore_outliers)
            return
        else:
            assert mode in valid_modes

        meta = tb_data.get('meta', {})
        nice = meta.get('nice', '?nice?')
        special_groupers = meta.get('special_groupers', ['loss'])

        fig = kwplot.figure(fnum=1)

        plot_keys = [
            key for key in tb_data
            if ('train_' + mode in key or 'vali_' + mode in key or 'test_' +
                mode in key or mode + '_' in key)
        ]
        y01_measures = [
            '_acc',
            '_ap',
            '_mAP',
            '_auc',
            '_mcc',
            '_brier',
            '_mauc',
        ]
        y0_measures = ['error', 'loss']

        keys = set(tb_data.keys()).intersection(set(plot_keys))

        # print('mode = {!r}'.format(mode))
        # print('tb_data.keys() = {!r}'.format(tb_data.keys()))
        # print('plot_keys = {!r}'.format(plot_keys))
        # print('keys = {!r}'.format(keys))

        def smooth_curve(ydata, beta):
            """
            Curve smoothing algorithm used by tensorboard
            """
            import pandas as pd
            alpha = 1.0 - beta
            if alpha <= 0:
                return ydata
            ydata_smooth = pd.Series(ydata).ewm(alpha=alpha).mean().values
            return ydata_smooth

        def inlier_ylim(ydatas):
            """
            outlier removal used by tensorboard
            """
            low, high = None, None
            for ydata in ydatas:
                q1 = 0.05
                q2 = 0.95
                low_, high_ = np.quantile(ydata, [q1, q2])

                # Extrapolate how big the entire span should be based on inliers
                inner_q = q2 - q1
                inner_extent = high_ - low_
                extrap_total_extent = inner_extent / inner_q

                # amount of padding to add to either side
                missing_p1 = q1
                missing_p2 = 1 - q2
                frac1 = missing_p1 / (missing_p2 + missing_p1)
                frac2 = missing_p2 / (missing_p2 + missing_p1)
                missing_extent = extrap_total_extent - inner_extent

                pad1 = missing_extent * frac1
                pad2 = missing_extent * frac2

                low_ = low_ - pad1
                high_ = high_ + pad2

                low = low_ if low is None else min(low_, low)
                high = high_ if high is None else max(high_, high)
            return (low, high)

        # Hack values that we don't apply smoothing to
        HACK_NO_SMOOTH = ['lr', 'momentum']

        def tag_grouper(k):
            # parts = ['train_epoch', 'vali_epoch', 'test_epoch']
            # parts = [p.replace('epoch', 'mode') for p in parts]
            parts = [p + mode for p in ['train_', 'vali_', 'test_']]
            for p in parts:
                if p in k:
                    return p.split('_')[0]
            return 'unknown'

        GROUP_LOSSES = True
        GROUP_AND_INDIVIDUAL = False
        INDIVIDUAL_PLOTS = True
        GROUP_SPECIAL = True

        if GROUP_LOSSES:
            # Group all losses in one plot for comparison
            loss_keys = [k for k in keys if 'loss' in k]
            tagged_losses = ub.group_items(loss_keys, tag_grouper)
            tagged_losses.pop('unknown', None)
            kw = {}
            kw['ymin'] = 0.0
            # print('tagged_losses = {!r}'.format(tagged_losses))
            for tag, losses in tagged_losses.items():

                min_abs_y = .01
                min_y = 0
                xydata = ub.odict()
                for key in sorted(losses):
                    ydata = tb_data[key]['ydata']

                    if not any(p in HACK_NO_SMOOTH
                               for p in key.split('_')):
                        ydata = smooth_curve(ydata, smoothing)

                    try:
                        min_y = min(min_y, ydata.min())
                        pos_ys = ydata[ydata > 0]
                        min_abs_y = min(min_abs_y, pos_ys.min())
                    except Exception:
                        pass

                    xydata[key] = (tb_data[key]['xdata'], ydata)

                kw['ymin'] = min_y

                if ignore_outliers:
                    low, kw['ymax'] = inlier_ylim(
                        [t[1] for t in xydata.values()])

                yscales = ['symlog', 'linear']
                for yscale in yscales:
                    fig.clf()
                    ax = fig.gca()
                    title = nice + '\n' + tag + '_' + mode + ' losses'
                    kwplot.multi_plot(xydata=xydata,
                                      ylabel='loss',
                                      xlabel=mode,
                                      yscale=yscale,
                                      title=title,
                                      fnum=1,
                                      ax=ax,
                                      **kw)
                    if yscale == 'symlog':
                        if LooseVersion(
                                mpl.__version__) >= LooseVersion('3.3'):
                            ax.set_yscale('symlog', linthresh=min_abs_y)
                        else:
                            ax.set_yscale('symlog', linthreshy=min_abs_y)
                    fname = '_'.join([tag, mode, 'multiloss', yscale]) + '.png'
                    fpath = join(out_dpath, fname)
                    ax.figure.savefig(fpath)

            # don't dump losses individually if we dump them in a group
            if not GROUP_AND_INDIVIDUAL:
                keys.difference_update(set(loss_keys))
                # print('keys = {!r}'.format(keys))

        if GROUP_SPECIAL:
            tag_groups = ub.group_items(keys, tag_grouper)
            tag_groups.pop('unknown', None)
            # Group items matching these strings
            kw = {}
            for tag, tag_keys in tag_groups.items():
                for groupname in special_groupers:
                    group_keys = [
                        k for k in tag_keys if groupname in k.split('_')
                    ]
                    if len(group_keys) > 1:
                        # Gather data for this group
                        xydata = ub.odict()
                        for key in sorted(group_keys):
                            ydata = tb_data[key]['ydata']
                            if not any(p in HACK_NO_SMOOTH
                                       for p in key.split('_')):
                                ydata = smooth_curve(ydata, smoothing)
                            xydata[key] = (tb_data[key]['xdata'], ydata)

                        if ignore_outliers:
                            low, kw['ymax'] = inlier_ylim(
                                [t[1] for t in xydata.values()])

                        yscales = ['linear']
                        for yscale in yscales:
                            fig.clf()
                            ax = fig.gca()
                            title = nice + '\n' + tag + '_' + mode + ' ' + groupname
                            kwplot.multi_plot(xydata=xydata,
                                              ylabel=groupname,
                                              xlabel=mode,
                                              yscale=yscale,
                                              title=title,
                                              fnum=1,
                                              ax=ax,
                                              **kw)
                            if yscale == 'symlog':
                                ax.set_yscale('symlog', linthreshy=min_abs_y)
                            fname = '_'.join([
                                tag, mode, 'group-' + groupname, yscale
                            ]) + '.png'
                            fpath = join(out_dpath, fname)
                            ax.figure.savefig(fpath)

                        if not GROUP_AND_INDIVIDUAL:
                            keys.difference_update(set(group_keys))

        if INDIVIDUAL_PLOTS:
            # print('keys = {!r}'.format(keys))
            for key in keys:
                d = tb_data[key]

                ydata = d['ydata']
                ydata = smooth_curve(ydata, smoothing)

                kw = {}
                if any(m.lower() in key.lower() for m in y01_measures):
                    kw['ymin'] = 0.0
                    kw['ymax'] = 1.0
                elif any(m.lower() in key.lower() for m in y0_measures):
                    kw['ymin'] = min(0.0, ydata.min())
                    if ignore_outliers:
                        low, kw['ymax'] = inlier_ylim([ydata])

                # NOTE: this is actually pretty slow
                fig.clf()
                ax = fig.gca()
                title = nice + '\n' + key
                kwplot.multi_plot(d['xdata'],
                                  ydata,
                                  ylabel=key,
                                  xlabel=mode,
                                  title=title,
                                  fnum=1,
                                  ax=ax,
                                  **kw)

                # png is slightly smaller than jpg for this kind of plot
                fpath = join(out_dpath, key + '.png')
                # print('save fpath = {!r}'.format(fpath))
                ax.figure.savefig(fpath)
Exemple #21
0
def ensure_array_nd(data, n):
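    """
    Coerce a scalar into an n-element array; arrays pass through unchanged.

    Example (a minimal sketch):
        >>> ensure_array_nd(3, n=2)
        array([3, 3])
        >>> ensure_array_nd([3, 4], n=2)
        array([3, 4])
    """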
    if ub.iterable(data):
        return np.array(data)
    else:
        return np.array([data] * n)
Exemple #22
0
def _filter_ignore_regions(true_dets,
                           pred_dets,
                           ioaa_thresh=0.5,
                           ignore_classes='ignore'):
    """
    Determine which true and predicted detections should be ignored.

    Args:
        true_dets (Detections): ground truth detections

        pred_dets (Detections): predicted detections

        ioaa_thresh (float): intersection over other-area threshold for
            ignoring a region.

        ignore_classes (str | List[str]): class name(s) indicating ignore
            regions

    Returns:
        Tuple[ndarray, ndarray]: flags indicating which true and predicted
            detections should be ignored.

    Example:
        >>> from kwcoco.metrics.assignment import *  # NOQA
        >>> from kwcoco.metrics.assignment import _filter_ignore_regions
        >>> import kwimage
        >>> pred_dets = kwimage.Detections.random(classes=['a', 'b', 'c'])
        >>> true_dets = kwimage.Detections.random(
        >>>     segmentations=True, classes=['a', 'b', 'c', 'ignore'])
        >>> ignore_classes = {'ignore', 'b'}
        >>> ioaa_thresh = 0.5
        >>> print('true_dets = {!r}'.format(true_dets))
        >>> print('pred_dets = {!r}'.format(pred_dets))
        >>> flags1, flags2 = _filter_ignore_regions(
        >>>     true_dets, pred_dets, ioaa_thresh=ioaa_thresh, ignore_classes=ignore_classes)
        >>> print('flags1 = {!r}'.format(flags1))
        >>> print('flags2 = {!r}'.format(flags2))

        >>> flags3, flags4 = _filter_ignore_regions(
        >>>     true_dets, pred_dets, ioaa_thresh=ioaa_thresh,
        >>>     ignore_classes={c.upper() for c in ignore_classes})
        >>> assert np.all(flags1 == flags3)
        >>> assert np.all(flags2 == flags4)
    """
    true_ignore_flags = np.zeros(len(true_dets), dtype=bool)
    pred_ignore_flags = np.zeros(len(pred_dets), dtype=bool)

    if not ub.iterable(ignore_classes):
        ignore_classes = {ignore_classes}

    def _normalize_catname(name, classes):
        if classes is None:
            return name
        if name in classes:
            return name
        for cname in classes:
            if cname.lower() == name.lower():
                return cname
        return name

    ignore_classes = {
        _normalize_catname(c, true_dets.classes)
        for c in ignore_classes
    }

    if true_dets.classes is not None:
        ignore_classes = ignore_classes & set(true_dets.classes)

    # Filter out true detections labeled as "ignore"
    if true_dets.classes is not None and ignore_classes:
        import kwarray
        ignore_cidxs = [true_dets.classes.index(c) for c in ignore_classes]
        true_ignore_flags = kwarray.isect_flags(true_dets.class_idxs,
                                                ignore_cidxs)

        if np.any(true_ignore_flags) and len(pred_dets):
            ignore_dets = true_dets.compress(true_ignore_flags)

            pred_boxes = pred_dets.data['boxes']
            ignore_boxes = ignore_dets.data['boxes']
            ignore_sseg = ignore_dets.data.get('segmentations', None)

            # Determine which predicted boxes are inside the ignore regions
            # note: using sum over max is deliberate here.
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', message='invalid .* less')
                warnings.filterwarnings('ignore',
                                        message='invalid .* greater_equal')
                warnings.filterwarnings('ignore',
                                        message='invalid .* true_divide')
                ignore_overlap = (pred_boxes.isect_area(ignore_boxes) /
                                  pred_boxes.area).clip(0, 1).sum(axis=1)
                ignore_overlap = np.nan_to_num(ignore_overlap)

            ignore_idxs = np.where(ignore_overlap > ioaa_thresh)[0]

            if ignore_sseg is not None:
                from shapely.ops import unary_union  # cascaded_union was removed in shapely 2.0
                # If the ignore region has segmentations further refine our
                # estimate of which predictions should be ignored.
                ignore_sseg = ignore_sseg.to_polygon_list()
                box_polys = ignore_boxes.to_polygons()
                ignore_polys = [
                    bp if p is None else p
                    for bp, p in zip(box_polys, ignore_sseg.data)
                ]
                ignore_regions = [p.to_shapely() for p in ignore_polys]
                ignore_region = unary_union(ignore_regions).buffer(0)

                cand_pred = pred_boxes.take(ignore_idxs)

                # Refine overlap estimates
                cand_regions = cand_pred.to_shapely()
                for idx, pred_region in zip(ignore_idxs, cand_regions):
                    try:
                        isect = ignore_region.intersection(pred_region)
                        overlap = (isect.area / pred_region.area)
                        ignore_overlap[idx] = overlap
                    except Exception as ex:
                        warnings.warn('ex = {!r}'.format(ex))
            pred_ignore_flags = ignore_overlap > ioaa_thresh
    return true_ignore_flags, pred_ignore_flags
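
A hedged sketch of how the returned flags might be consumed downstream.
Detections.compress is used inside the function itself, but this surrounding
evaluation snippet is an assumption rather than kwcoco's confirmed API:

    import numpy as np
    true_flags, pred_flags = _filter_ignore_regions(
        true_dets, pred_dets, ioaa_thresh=0.5, ignore_classes={'ignore'})
    # keep only the detections that should participate in scoring
    scored_true = true_dets.compress(~np.asarray(true_flags))
    scored_pred = pred_dets.compress(~np.asarray(pred_flags))
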
Example #23
0
    def variety_selection(sampler, num=20):
        """
        Select num image ids that cover a variety of annotation types,
        categories, source datasets, and capture times.
        """
        import numpy as np
        import netharn as nh  # assumption: nh.util.shuffle below comes from netharn
        dset = sampler.dset

        gid_to_props = ub.odict()
        for gid, img in dset.imgs.items():
            aids = dset.gid_to_aids[gid]
            annot_types = frozenset(dset.anns[aid]['roi_shape']
                                    for aid in aids)
            annot_cids = frozenset(dset.anns[aid]['category_id']
                                   for aid in aids)
            gid_to_props[gid] = ub.odict()
            gid_to_props[gid]['num_aids'] = len(aids)
            gid_to_props[gid]['annot_types'] = annot_types
            gid_to_props[gid]['annot_cats'] = annot_cids
            gid_to_props[gid]['orig_dset'] = frozenset([img['orig_dset']])

            try:
                from datetime import datetime
                datetime_object = datetime.strptime(img['date'],
                                                    '%Y-%m-%d %H:%M:%S')
            except Exception:
                print('failed to parse time: {}'.format(img.get('date', None)))
                gid_to_props[gid]['time'] = None
            else:
                gid_to_props[gid]['time'] = datetime_object.toordinal()

        if True:
            # Handle items without a parsable time
            all_ts = []
            for p in gid_to_props.values():
                if p['time'] is not None:
                    all_ts.append(p['time'])
            if len(all_ts) == 0:
                all_ts = [735857, 735859, 735850]
            all_ts = np.array(all_ts)
            mean_t = all_ts.mean()
            std_t = all_ts.std()
            for p in gid_to_props.values():
                if p['time'] is None:
                    p['time'] = mean_t + np.random.randn() * std_t

        basis_values = ub.ddict(set)
        for gid, props in gid_to_props.items():
            for key, value in props.items():
                if ub.iterable(value):
                    basis_values[key].update(value)

        basis_values = ub.map_vals(sorted, basis_values)

        # Build a descriptor to find a "variety" of images
        gid_to_desc = {}
        for gid, props in gid_to_props.items():
            desc = []
            for key, value in props.items():
                if ub.iterable(value):
                    hotvec = np.zeros(len(basis_values[key]))
                    for v in value:
                        idx = basis_values[key].index(v)
                        hotvec[idx] = 1
                    desc.append(hotvec)
                else:
                    desc.append([value])
            gid_to_desc[gid] = list(ub.flatten(desc))

        gids = np.array(list(gid_to_desc.keys()))
        vecs = np.array(list(gid_to_desc.values()))

        from sklearn import cluster
        algo = cluster.KMeans(n_clusters=num,
                              n_init=20,
                              max_iter=10000,
                              tol=1e-6,
                              algorithm='elkan',
                              verbose=0)
        algo.fit(vecs)

        assignment = algo.predict(vecs)
        grouped_gids = ub.group_items(gids, assignment)

        gid_list = [nh.util.shuffle(group)[0]
                    for group in grouped_gids.values()]
        return gid_list
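
Assuming sampler is an ndsampler-style object that exposes a dset attribute
(as used above), a hypothetical call could look like the following; the
subset step is an assumption about the dataset API, not a confirmed method:

    # pick 20 mutually dissimilar images and build a small dataset from them
    gid_list = variety_selection(sampler, num=20)
    varied_dset = sampler.dset.subset(gid_list)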