Пример #1
0
    def plot_cpd(self,
                 condition: SingleCategoryPTMixin,
                 max_pct: float = 1.0,
                 ax: Optional[Axes] = None,
                 **kwargs):
        """
        Plot distributions of answers to the question, conditioned on the
        values of discrete distribution `condition`.

        The data are grouped by `condition` through `kdeplot`'s `hue`
        argument.

        :param condition: Discrete distribution whose values split the data.
        :param max_pct: Upper quantile of this distribution's data to keep;
                        values above it are excluded. 1.0 keeps everything.
        :param ax: Optional Axes to draw on (passed to AxesFormatter).
        :param kwargs: Additional keyword arguments passed to `kdeplot`.
        :return: The Axes the plot was drawn on.
        """
        axf = AxesFormatter(axes=ax)
        a_ix = self._data.index
        if max_pct < 1:
            # drop the tail above the requested quantile
            a_ix = self._data.loc[
                self._data <= self._data.quantile(max_pct)].index
        a = self._data.loc[a_ix]
        condition_data = condition.data.loc[a_ix]
        # draw on the formatter's axes (not the raw `ax`, which may be None)
        # so the labels set below apply to the same axes as the curves
        kdeplot(x=a, hue=condition_data, ax=axf.axes, **kwargs)
        axf.set_text(x_label=self.name,
                     y_label=f'P({self.name}|{condition.name})')
        return axf.axes
Пример #2
0
    def draw(self, t: float, axes: AxesFormatter):
        """
        Add a wedge to `axes` for time `t`.

        Float and color properties are collected through `_add_float_kwarg`
        and `_add_color_kwarg`, which receive `t`. The remaining properties
        are collected through `_add_non_animated_kwarg`. The result is passed
        to `axes.add_wedge`.
        """
        wedge_kwargs = {}
        float_names = ('x_center', 'y_center', 'radius', 'theta_start',
                       'theta_end', 'width', 'alpha', 'line_width')
        color_names = ('color', 'edge_color', 'face_color')
        static_names = ('cap_style', 'fill', 'join_style', 'label',
                        'line_style')
        for name in float_names:
            self._add_float_kwarg(name, wedge_kwargs, t)
        for name in color_names:
            self._add_color_kwarg(name, wedge_kwargs, t)
        for name in static_names:
            self._add_non_animated_kwarg(name, wedge_kwargs)
        axes.add_wedge(**wedge_kwargs)
Пример #3
0
    def draw(self, t: float, axes: AxesFormatter):
        """
        Add a regular polygon to `axes` for time `t`.

        Time-dependent float and color properties are gathered first, then
        the fixed properties. The collected keyword arguments are passed to
        `axes.add_regular_polygon`.
        """
        polygon_kwargs = {}
        # (helper, property names) pairs for properties whose helper takes `t`
        timed_groups = (
            (self._add_float_kwarg,
             ('x_center', 'y_center', 'radius', 'angle', 'alpha',
              'line_width')),
            (self._add_color_kwarg,
             ('color', 'edge_color', 'face_color')),
        )
        for add_kwarg, names in timed_groups:
            for name in names:
                add_kwarg(name, polygon_kwargs, t)
        for name in ('num_vertices', 'fill', 'label', 'line_style',
                     'cap_style', 'join_style'):
            self._add_non_animated_kwarg(name, polygon_kwargs)
        axes.add_regular_polygon(**polygon_kwargs)
def draw_tree(decision_tree: DecisionTree):
    """
    Draw `decision_tree` twice and show each drawing.

    The first drawing labels nodes by 'name' and the second by 'amount'.
    Each is drawn on its own AxesFormatter with width=12 and height=16.
    """
    for label_field in ('name', 'amount'):
        formatter = AxesFormatter(width=12, height=16)
        decision_tree.draw(node_labels=label_field, ax=formatter.axes)
        formatter.show()
Пример #5
0
    def draw(self, t: float, axes: AxesFormatter):
        """
        Add an arrow to `axes` for time `t`.

        Properties are gathered with the `_add_*_kwarg` helpers and passed to
        `axes.add_arrow`. Float and color helpers receive `t`.
        """
        arrow_kwargs = {}
        for prop in ('x_tail', 'y_tail', 'dx', 'dy', 'width', 'alpha',
                     'line_width'):
            self._add_float_kwarg(prop, arrow_kwargs, t)
        for prop in ('color', 'edge_color', 'face_color'):
            self._add_color_kwarg(prop, arrow_kwargs, t)
        for prop in ('cap_style', 'fill', 'join_style', 'label',
                     'line_style'):
            self._add_non_animated_kwarg(prop, arrow_kwargs)
        axes.add_arrow(**arrow_kwargs)
Пример #6
0
    def draw(self, t: float, axes: AxesFormatter):
        """
        Add an arc to `axes` for time `t`.

        When `self.length` is set, the arc's end angle is moved from
        `theta_start` towards `theta_end` by the fraction `self.length.at(t)`.
        Otherwise `self.theta_end` is used as-is.
        """
        arc_kwargs = {}
        if self.length is None:
            arc_kwargs['theta_end'] = self.theta_end
        else:
            # NOTE(review): theta_start/theta_end are used directly here,
            # while theta_start is also passed through _add_float_kwarg
            # below; confirm neither is meant to be animated.
            sweep = self.theta_end - self.theta_start
            arc_kwargs['theta_end'] = (self.theta_start
                                       + self.length.at(t) * sweep)
        float_props = ('x_center', 'y_center', 'width', 'height', 'angle',
                       'theta_start', 'alpha', 'line_width')
        for prop in float_props:
            self._add_float_kwarg(prop, arc_kwargs, t)
        for prop in ('color', 'edge_color'):
            self._add_color_kwarg(prop, arc_kwargs, t)
        for prop in ('cap_style', 'join_style', 'label', 'line_style'):
            self._add_non_animated_kwarg(prop, arc_kwargs)
        axes.add_arc(**arc_kwargs)
Пример #7
0
    def draw(self, t: float, axes: AxesFormatter):
        """
        Add a line to `axes` for time `t`.

        When `self.length` is set, only the leading fraction
        `self.length.at(t)` of the x and y points is drawn. The point count
        is rounded to the nearest integer.
        """
        line_kwargs = {}
        if self.length is None:
            line_kwargs['x'] = self.x
            line_kwargs['y'] = self.y
        else:
            def leading(points):
                # leading share of `points` given by the animated length
                n_points = int(round(self.length.at(t) * len(points)))
                return points[:n_points]

            line_kwargs['x'] = leading(self.x)
            line_kwargs['y'] = leading(self.y)
        for prop in ('alpha', 'line_width', 'marker_edge_width',
                     'marker_size'):
            self._add_float_kwarg(prop, line_kwargs, t)
        for prop in ('color', 'marker_edge_color', 'marker_face_color',
                     'marker_face_color_alt'):
            self._add_color_kwarg(prop, line_kwargs, t)
        for prop in ('draw_style', 'label', 'line_style', 'marker'):
            self._add_non_animated_kwarg(prop, line_kwargs)
        axes.add_line(**line_kwargs)
Пример #8
0
    def plot_hist(
            self,
            color: Color = 'k',
            axf: Optional[AxesFormatter] = None
    ) -> AxesFormatter:
        """
        Draw a histogram of the distribution's data with `histplot`.

        :param color: Color of the bars.
        :param axf: Optional AxesFormatter instance; a new one is created
                    when not given.
        :return: The AxesFormatter the histogram was drawn on.
        """
        formatter = axf or AxesFormatter()
        histplot(data=self._data, color=color, ax=formatter.axes)
        return formatter
Пример #9
0
    def plot_bars(self,
                  color: Color = 'k',
                  pct_font_size: int = FONT_SIZE.medium,
                  max_pct: Optional[float] = None,
                  max_value: Optional[int] = None,
                  axf: Optional[AxesFormatter] = None) -> AxesFormatter:
        """
        Draw one bar per value showing its count, labelled with the value's
        percentage from `pmf()`.

        :param color: Color of the bars.
        :param pct_font_size: Font size for percentage labels.
        :param max_pct: Highest percentile value of the data to plot a count of.
                        Useful for long-tail distributions.
        :param max_value: Highest value of the data to plot a count of.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter the bars were drawn on.
        """
        formatter = axf or AxesFormatter()
        bar_counts = self.counts()
        probabilities = self.pmf()
        # upper bounds on the *values* shown, applied in order
        value_limits = []
        if max_pct is not None:
            value_limits.append(self._data.quantile(max_pct))
        if max_value is not None:
            value_limits.append(max_value)
        for limit in value_limits:
            bar_counts = bar_counts.loc[bar_counts.index <= limit]
            probabilities = probabilities.loc[bar_counts.index]
        bar_counts.plot.bar(ax=formatter.axes, color=color)
        pct_labels = (100 * probabilities).map(lambda p: f'{p: .1f}%')
        formatter.add_text(x=range(len(bar_counts)),
                           y=bar_counts,
                           text=pct_labels,
                           h_align='center',
                           v_align='bottom',
                           font_size=pct_font_size)
        formatter.y_axis.set_format_integer()
        return formatter
Пример #10
0
    def plot_bars(self,
                  color: Color = 'k',
                  pct_font_size: int = FONT_SIZE.medium,
                  axf: Optional[AxesFormatter] = None) -> AxesFormatter:
        """
        Draw one bar per category showing its count, labelled with the
        category's percentage from `pmf()`.

        :param color: Color of the bars.
        :param pct_font_size: Font size for percentage labels.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter the bars were drawn on.
        """
        formatter = axf or AxesFormatter()
        bar_counts = self.counts()
        probabilities = self.pmf()
        bar_counts.plot.bar(ax=formatter.axes, color=color)
        pct_labels = (100 * probabilities).map(lambda p: f'{p: .1f}%')
        formatter.add_text(x=range(len(bar_counts)),
                           y=bar_counts,
                           text=pct_labels,
                           h_align='center',
                           v_align='bottom',
                           font_size=pct_font_size)
        formatter.y_axis.set_format_integer()
        return formatter
Пример #11
0
    def plot_bars(self,
                  conditional: bool = False,
                  max_pct: Optional[float] = None,
                  width: float = 0.8,
                  height: float = 0.8,
                  color: Color = 'k',
                  color_min: Optional[Color] = None,
                  alpha_min: float = 0.0,
                  edge_color: Optional[Color] = 'grey',
                  axf: Optional[AxesFormatter] = None):
        """
        Plot a set of bars for each Series, shaded by the count at each discrete
        value.

        :param conditional: Whether the shading for each Series should be
                            independent of all the others.
        :param max_pct: Highest percentile of each Series to show a bar for.
        :param width: Width of each set of bars.
        :param height: Height of each bar, centered about the count value.
        :param color: Color of each bar.
        :param color_min: Optional different color for the sparsest bars.
        :param alpha_min: Alpha value for the sparsest bars.
        :param edge_color: Optional edge color of each bar.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter the bars were drawn on.
        """
        axf = axf or AxesFormatter()
        dist: Count
        # find highest count of all Series
        max_count = max(counts.max() for counts in self.counts())
        # highest value for y-axis scaling; when filtering by percentile it is
        # accumulated from each Series' percentile value inside the loop
        if max_pct is None:
            max_val = self.max().max()
        else:
            max_val = 0
        # colour of the sparsest bars does not depend on the bar; compute once
        if color_min is not None:
            from_color = set_alpha(color_min, alpha_min)
        else:
            from_color = set_alpha(color, alpha_min)
        # plot each Series
        for x, (_, dist) in enumerate(self._data.items()):
            dist_items = dist.counts()
            # test against None (not truthiness) so max_pct=0.0 still filters,
            # consistent with the y-limit setup above
            if max_pct is not None:
                max_dist_val = dist.data.quantile(max_pct)
                dist_items = dist_items.loc[dist_items.index <= max_dist_val]
                max_val = max(max_val, max_dist_val)
            if conditional:
                max_count = dist_items.max()
            for dist_value, dist_count in dist_items.items():
                coords = dict(width=width,
                              height=height,
                              x_center=(1 + x),
                              y_center=dist_value)
                axf.add_rectangle(**coords,
                                  color=cross_fade(from_color=from_color,
                                                   to_color=color,
                                                   amount=dist_count /
                                                   max_count))
                if edge_color is not None:
                    axf.add_rectangle(**coords,
                                      fill=False,
                                      color=None,
                                      edge_color=edge_color)
        axf.set_x_lim(0.5, len(self._data) + 0.5)
        axf.set_y_lim(self.min().min() - 1, max_val + 1)
        axf.x_ticks.set_locations(range(1, len(self._data) + 1))
        axf.x_ticks.set_labels(self._data.index.to_list())
        axf.y_axis.set_format_integer()

        return axf
Пример #12
0
    def __init__(self, formatter: Union[AxesFormatter, FigureFormatter],
                 duration: float):
        """
        Store the formatter, the duration and an initially empty list of
        ShapeAnimation objects.

        :param formatter: Formatter to draw on; a new AxesFormatter is
                          created when a falsy value (e.g. None) is given.
        :param duration: Duration value stored on the instance.
        """
        if not formatter:
            formatter = AxesFormatter()
        self.formatter: Union[AxesFormatter, FigureFormatter] = formatter
        self.duration: float = duration
        self.shapes: List[ShapeAnimation] = []
Пример #13
0
    def plot_density_bars(
            self,
            bin_spacing: float,
            conditional: bool = True,
            width: float = 0.8,
            min_pct: float = 0.0,
            max_pct: float = 1.0,
            min_alpha: float = 0.0,
            color: Color = 'k',
            edge_color: Optional[Color] = 'grey',
            axf: Optional[AxesFormatter] = None) -> AxesFormatter:
        """
        Plot a density bar for each Ratio distribution.

        Each distribution is binned with its `histogram` method. Every bin is
        drawn as a rectangle whose opacity rises linearly from `min_alpha`
        with the bin's count.

        :param bin_spacing: Spacing between each bin used to calculate
                            histogram of each Ratio distribution.
        :param conditional: Whether to set opacity relative to the highest bin
                            count of each Series (True) or all Series (False).
        :param width: Width of each bar.
        :param min_pct: Minimum quantile to plot from 0.0 to 1.0.
        :param max_pct: Maximum quantile to plot from 0.0 to 1.0.
        :param min_alpha: Alpha value for the opacity of bins with 0 count.
        :param color: Color for each density bar.
        :param edge_color: Color for edge of each bar. Set to None to omit
                           edges.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter the bars were drawn on.
        """
        axf = axf or AxesFormatter()
        dist: Ratio
        # calculate histograms, counts and bin limits
        hists = [
            dist.histogram(bins=float(bin_spacing),
                           min_pct=min_pct,
                           max_pct=max_pct) for _, dist in self._data.items()
        ]
        # histogram index has 'min'/'max' levels; assumes bins are in
        # ascending order so first/last entries give the outer limits
        lowest_bin = min(
            hist.index.get_level_values('min')[0] for hist in hists)
        highest_bin = max(
            hist.index.get_level_values('max')[-1] for hist in hists)
        max_counts = [hist.max() for hist in hists]
        max_max_count = max(max_counts)
        for x, (hist, max_count) in enumerate(zip(hists, max_counts)):
            # normalise opacity by this Series' tallest bin (conditional) or
            # by the tallest bin across all Series; truthiness (not `is True`)
            # so numpy booleans and other truthy values behave as expected
            norm_count = max_count if conditional else max_max_count
            for (min_val, max_val), count in hist.items():
                # add segment interior
                alpha = clip(min_alpha + (1 - min_alpha) * count / norm_count,
                             min_alpha, 1.0)
                axf.add_rectangle(width=width,
                                  height=max_val - min_val,
                                  x_center=x,
                                  y_center=(min_val + max_val) / 2,
                                  color=color,
                                  line_width=0,
                                  alpha=alpha)
            # add density edges
            if edge_color is not None:
                low_bin = hist.index.get_level_values('min')[0]
                high_bin = hist.index.get_level_values('max')[-1]
                axf.add_rectangle(width=width,
                                  height=high_bin - low_bin,
                                  x_center=x,
                                  y_center=(low_bin + high_bin) / 2,
                                  fill=False,
                                  color=None,
                                  edge_color=edge_color)
        # format axes
        axf.set_x_lim(-1, len(self._data))
        axf.set_y_lim(lowest_bin - bin_spacing, highest_bin + bin_spacing)
        axf.x_ticks.set_locations(range(len(self._data)))
        axf.x_ticks.set_labels(self._data.index.to_list())
        axf.y_axis.set_format_integer()

        return axf
Пример #14
0
def plot_probs(probs: DataFrame, weight):
    """
    Box-plot posterior probabilities by likelihood, split by prior, and show
    the figure.

    :param probs: DataFrame with 'likelihood' and 'prior' column levels.
                  Presumably these are its only column levels, so stacking
                  them yields a Series, which is renamed 'posterior'.
    :param weight: Value used as the plot title (converted with `str`).
    """
    axf = AxesFormatter()
    data = probs.stack(
        level=['likelihood', 'prior']).rename('posterior').reset_index()
    # draw explicitly on the formatter's axes instead of the current axes,
    # so the formatting below is applied to the plotted axes
    boxplot(data=data, x='likelihood', y='posterior', hue='prior',
            ax=axf.axes)
    axf.rotate_x_tick_labels(90)
    axf.set_y_lim(0, 1.05)
    axf.set_axis_below().grid()
    axf.set_text(title=str(weight))
    axf.show()
Пример #15
0
    def plot_density_bars(
            self,
            color: Union[Color, List[Color]] = 'k',
            color_min: Optional[Union[Color, List[Color]]] = None,
            width: Union[float, List[float]] = 0.8,
            hdi: float = 0.95,
            z_max: Optional[Union[float, List[float]]] = None,
            resolution: int = 100,
            edges: bool = False,
            orient: str = 'v',
            axf: Optional[AxesFormatter] = None) -> AxesFormatter:
        """
        Plot each distribution as a density bar.

        :param color: Color of each bar, all bars or list to cycle through.
        :param color_min: Min color of each bar, all bars or list to cycle
                          through.
        :param width: Width of each bar.
        :param hdi: Highest Density Interval width for each distribution.
        :param z_max: Optional normalizing constant to divide each bar's height
                      by.
        :param resolution: Number of intervals the HDI is divided into when
                           evaluating each pdf.
        :param edges: Whether to plot the edges of each bar.
        :param orient: Orientation. One of {'v', 'h'}.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter the bars were drawn on.
        :raises ValueError: If `orient` is not 'v' or 'h'.
        """
        # validate up front: plotting and axis formatting branch on `orient`
        # separately, so any other value would produce an inconsistent plot
        if orient not in ('v', 'h'):
            raise ValueError(f"orient must be 'v' or 'h', got {orient!r}")
        axf = axf or AxesFormatter()
        dist: Union[PPFContinuous1dMixin, PDF1dMixin]
        num_dists = len(self._data)
        # expand scalar-or-list arguments into one entry per distribution
        color = loop_variable(color, num_dists)
        color_min = loop_variable(color_min, num_dists)
        width = loop_variable(width, num_dists)
        z_max = loop_variable(z_max, num_dists)

        for i, (_, dist) in enumerate(self._data.items()):
            lower, upper = dist.hdi(hdi)
            y_to_z = dist.pdf().at(linspace(lower, upper, resolution + 1))
            if orient == 'h':
                axf.add_h_density(y=i + 1,
                                  x_to_z=y_to_z,
                                  color=color[i],
                                  color_min=color_min[i],
                                  height=width[i],
                                  z_max=z_max[i],
                                  v_align='center')
            else:
                axf.add_v_density(x=i + 1,
                                  y_to_z=y_to_z,
                                  color=color[i],
                                  color_min=color_min[i],
                                  width=width[i],
                                  z_max=z_max[i],
                                  h_align='center')
            if not edges:
                continue
            # `color_min` is per-bar after loop_variable, so test this bar's
            # entry and use this bar's colour (not the whole list)
            if color_min[i] is None:
                bar_edge_color = color[i]
            else:
                bar_edge_color = cross_fade(color_min[i], color[i], 0.5)
            if orient == 'h':
                axf.add_rectangle(
                    width=y_to_z.index[-1] - y_to_z.index[0],
                    height=width[i],
                    x_left=y_to_z.index[0],
                    y_bottom=i + 1 - width[i] / 2,
                    edge_color=bar_edge_color,
                    fill=False)
            else:
                axf.add_rectangle(
                    width=width[i],
                    height=y_to_z.index[-1] - y_to_z.index[0],
                    x_left=i + 1 - width[i] / 2,
                    y_bottom=y_to_z.index[0],
                    edge_color=bar_edge_color,
                    fill=False)
        if orient == 'v':
            axf.set_x_lim(0, num_dists + 1)
            axf.x_ticks.set_locations(range(1, num_dists + 1))
            axf.x_ticks.set_labels(self._data.index)
            axf.y_ticks.set_locations(arange(0, 1.1, 0.1))
            axf.set_y_lim(-0.05, 1.05)
        else:
            axf.set_y_lim(0, num_dists + 1)
            axf.y_ticks.set_locations(range(1, num_dists + 1))
            axf.y_ticks.set_labels(self._data.index)
            axf.x_ticks.set_locations(arange(0, 1.1, 0.1))
            axf.set_x_lim(-0.05, 1.05)
        return axf
Пример #16
0
    def plot_kde(
            self,
            color: Color = 'k',
            inflated_value: float = 0.0,
            inflated_threshold: float = 0.5,
            mean_line: Union[bool, dict] = False,
            median_line: Union[bool, dict] = False,
            axf: Optional[AxesFormatter] = None
    ) -> AxesFormatter:
        """
        Plot a kde plot of the distribution.

        :param color: Color of the kde curves.
        :param inflated_value: Value which may occur disproportionately often in
                               the data due to a mixed distribution (e.g. ZIPF),
                               or collection method
                               (e.g. code all > 100% as 101%).
        :param inflated_threshold: A second, dashed kde of the values NOT
                                   equal to `inflated_value` is drawn when the
                                   proportion of such values exceeds this
                                   threshold. NOTE(review): confirm this
                                   comparison direction is the intended one.
        :param mean_line: Boolean flag to indicate whether to annotate the mean.
                          Or dict of kws to pass to AxesFormatter.add_v_line.
        :param median_line: Boolean flag to indicate whether to annotate the
                            median.
                            Or dict of kws to pass to AxesFormatter.add_v_line.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter the plot was drawn on.
        """
        axf = axf or AxesFormatter()
        # main kde
        kdeplot(
            data=self._data,
            color=color,
            ax=axf.axes,
            label=self.name
        )
        # non-inflated kde; guard against division by zero on empty data
        non_inflated = self._data.loc[self._data != inflated_value]
        n_total = len(self._data)
        if n_total > 0 and len(non_inflated) / n_total > inflated_threshold:
            kdeplot(
                data=non_inflated,
                color=color,
                ls='--',
                ax=axf.axes,
                label=f'{self.name} != {inflated_value}'
            )

        def line_kws(option: Union[bool, dict],
                     default_color: str) -> Optional[dict]:
            # True -> default dotted style, False -> no line, dict -> as given
            if isinstance(option, bool):
                if option:
                    return {'color': default_color, 'line_style': ':'}
                return None
            return option

        # (label, statistic, option, default colour, text height fraction)
        annotations = (
            ('mean', self._data.mean, mean_line, 'green', 0.8),
            ('median', self._data.median, median_line, 'blue', 0.7),
        )
        for stat_name, stat_func, option, default_color, y_frac in annotations:
            kws = line_kws(option, default_color)
            if kws is None:
                continue
            value = stat_func()
            axf.add_v_line(x=value, **kws)
            # read the limits after adding the line, as they may change
            y_min, y_max = axf.get_y_min(), axf.get_y_max()
            axf.add_text(x=value, y=y_min + y_frac * (y_max - y_min),
                         text=f'{stat_name} = {value: 0.1f}')
        return axf
Пример #17
0
    def plot_comparison_bars(
            self,
            other: 'DataDiscreteNumericMixin',
            absolute: bool = False,
            color: Tuple[Color, Color] = ('C0', 'C1'),
            width: float = 0.5,
            label_pcts: bool = True,
            label_counts: bool = False,
            label_size: Optional[FontSize] = FONT_SIZE.medium,
            max_pct: Optional[float] = None,
            max_value: Optional[int] = None,
            axf: Optional[AxesFormatter] = None) -> AxesFormatter:
        """
        Plot the value counts of this distribution and `other` side by side,
        as pairs of bars with optional percentage / count labels.

        :param other: Another distribution to compare with; must have a
                      different name.
        :param absolute: Whether to plot bar heights as absolute values or
                        percentages.
        :param color: Color for each set of bars.
        :param width: Total width of each pair of bars.
        :param label_pcts: Whether to add percentage labels.
        :param label_counts: Whether to add count labels.
        :param label_size: Font size for bar labels.
        :param max_pct: Highest percentile value of the data to plot a count of.
                        Useful for long-tail distributions.
        :param max_value: Highest value of the data to plot a count of.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter the bars were drawn on.
        :raises ValueError: If both distributions have the same name.
        """
        # validation
        if self.name == other.name:
            raise ValueError(
                'Distributions must have different names in order to compare.')
        categories = sorted(set(self._categories) | set(other._categories))
        # get data; a value missing from one distribution has a count of 0
        self_counts = self._data.value_counts().reindex(categories).fillna(0)
        other_counts = other._data.value_counts().reindex(
            categories).fillna(0)
        # name the columns explicitly: since pandas 2.0 value_counts() names
        # its result 'count', which would give duplicate column names
        count_data = concat([self_counts, other_counts], axis=1,
                            keys=[self.name, other.name])
        # percentages are relative to all data, before any filtering below
        pct_data = count_data / count_data.sum()
        if max_pct is not None:
            max_pct_value = max(self._data.quantile(max_pct),
                                other._data.quantile(max_pct))
            count_data = count_data.loc[count_data.index <= max_pct_value]
        if max_value is not None:
            count_data = count_data.loc[count_data.index <= max_value]
        pct_data = pct_data.loc[count_data.index]
        # plot bars
        axf = axf or AxesFormatter()
        plot_data = count_data if absolute else pct_data
        plot_data.plot.bar(ax=axf.axes, color=color, width=width)
        # add labels
        if label_pcts or label_counts:
            for o, column in enumerate(plot_data.columns):
                for i, ix in enumerate(plot_data[column].index):
                    # centre of the o-th bar in the i-th pair
                    label_x = i - width / 4 + o * width / 2
                    label_y = plot_data.loc[ix, column]
                    texts = []
                    if label_pcts:
                        texts.append(
                            format_as_percent(pct_data.loc[ix, column]))
                    if label_counts:
                        texts.append(
                            format_as_integer(count_data.loc[ix, column]))
                    axf.add_text(label_x,
                                 label_y,
                                 ' | '.join(texts),
                                 font_size=label_size,
                                 h_align='center',
                                 v_align='bottom')
        # format y-axis
        if absolute:
            axf.y_axis.set_format_integer()
        else:
            axf.y_axis.set_format_percent()

        return axf
Пример #18
0
    def plot_conditional_dist_densities(
            self,
            categorical: Union[DataDistributionMixin, DataCategoriesMixin],
            width: float = 0.8,
            heights: float = 0.9,
            color: Color = 'k',
            color_min: Optional[Color] = None,
            color_mean: Optional[Color] = None,
            color_median: Optional[Color] = None,
            axf: Optional[AxesFormatter] = None
    ) -> AxesFormatter:
        """
        Plot conditional probability densities of the data, split by the
        categories of an Ordinal or Nominal distribution.

        Each category of `categorical` gets a column of rectangles, one per
        category of this distribution, labelled with the conditional
        percentage.

        :param categorical: Nominal or Ordinal distribution.
        :param width: Width of the widest density bar; each category's bar is
                      scaled by its count relative to the largest category.
        :param heights: Height of each density bar.
        :param color: Color for the densest part of each distribution.
        :param color_min: Color for the sparsest part of each distribution,
                          if different to color.
        :param color_mean: Color for mean data markers.
        :param color_median: Color for median data markers.
        :param axf: Optional AxesFormatter to plot on.
        :return: The AxesFormatter the plot was drawn on.
        """
        axf = axf or AxesFormatter()
        cats = categorical.categories
        n_cats = len(cats)
        # restrict both datasets to the rows they share
        shared_ix = list(
            set(self._data.index).intersection(categorical.data.index)
        )
        cat_data = categorical.data.loc[shared_ix]
        ordinal_data = self._data.loc[shared_ix]
        # value counts of this data within each category, computed once
        cat_value_counts = [
            ordinal_data.loc[cat_data == category].value_counts()
            for category in cats
        ]
        max_cat_sum = max(vc.sum() for vc in cat_value_counts)
        max_pct = max(vc.max() / vc.sum() for vc in cat_value_counts)
        for c, category in enumerate(cats):
            cat_ratio_data = ordinal_data.loc[cat_data == category]
            value_counts = cat_ratio_data.value_counts().reindex(
                self.categories).fillna(0)
            cat_sum = value_counts.sum()
            for i, (item, count) in enumerate(value_counts.items()):
                pct = count / cat_sum
                if color_min is not None:
                    rect_color = cross_fade(color_min, color, pct / max_pct)
                else:
                    rect_color = color
                bar_width = width * cat_sum / max_cat_sum
                x_center = 1 + c
                y_center = 1 + i
                axf.add_rectangle(
                    width=bar_width, height=heights,
                    x_left=x_center - bar_width / 2,
                    y_bottom=y_center - heights / 2,
                    color=rect_color,
                    alpha=pct / max_pct
                )
                axf.add_text(x=x_center, y=y_center,
                             text=format_as_percent(pct, 1),
                             h_align='center', v_align='center',
                             bbox_edge_color='k', bbox_fill=True,
                             bbox_face_color='white')
            if len(cat_ratio_data) == 0:
                continue
            # plot descriptive statistics lines at the bar positions (1 + code);
            # code -1 marks a missing value and must not count as position 0
            codes = cat_ratio_data.cat.codes
            positions = 1 + codes[codes >= 0]
            if color_mean is not None:
                mean = positions.mean()
                axf.add_line(x=[c + 0.55, c + 1.45], y=[mean, mean],
                             color=color_mean)
            if color_median is not None:
                median = positions.median()
                axf.add_line(x=[c + 0.55, c + 1.45], y=[median, median],
                             color=color_median)
        # labels
        axf.set_text(
            title=f'Distributions of p({self.name}|{categorical.name})',
            x_label=categorical.name,
            y_label=self.name
        )
        # axes
        axf.set_x_lim(0.5, n_cats + 0.5)
        axf.set_y_lim(0, len(self._categories) + 1)
        axf.x_ticks.set_locations(range(1, n_cats + 1)).set_labels(cats)
        axf.y_ticks.set_locations(
            range(1, len(self._categories) + 1)).set_labels(self._categories)

        return axf
Пример #19
0
    def plot_comparison_bars(
            self,
            other: 'DataDiscreteCategoricalMixin',
            absolute: bool = False,
            color: Tuple[Color, Color] = ('C0', 'C1'),
            width: float = 0.5,
            label_pcts: bool = True,
            label_counts: bool = False,
            label_size: Optional[FontSize] = FONT_SIZE.medium,
            axf: Optional[AxesFormatter] = None) -> AxesFormatter:
        """
        Plot a side-by-side bar comparison of this distribution and `other`.

        One pair of bars is drawn per category of this distribution, in the
        order of ``self._categories``. Categories that only exist in `other`
        are reported in a printed warning but are not plotted. Categories
        with no observations are plotted with a count of 0.

        :param other: Another categorical distribution, ideally with the same
                      categories.
        :param absolute: Whether to plot bar heights as absolute counts
                         (True) or as percentages of each distribution's
                         total (False).
        :param color: Color for each set of bars.
        :param width: Total width of each pair of bars.
        :param label_pcts: Whether to add percentage labels.
        :param label_counts: Whether to add count labels.
        :param label_size: Font size for bar labels.
        :param axf: Optional AxesFormatter instance.
        :raises ValueError: If both distributions have the same name (the
                            names are used as the bar series labels).
        :return: The AxesFormatter the bars were drawn on.
        """
        # validation
        if self.name == other.name:
            raise ValueError(
                'Distributions must have different names in order to compare.')
        self_cats, other_cats = set(self._categories), set(other._categories)
        if self_cats != other_cats:
            str_warning = 'WARNING: Distributions contain different categories'
            # sort + str() so the message is deterministic and does not fail
            # for non-string category values
            unique_1 = sorted(map(str, self_cats.difference(other_cats)))
            unique_2 = sorted(map(str, other_cats.difference(self_cats)))
            if unique_1:
                str_warning += f'\nOnly in {self.name}: {", ".join(unique_1)}'
            if unique_2:
                str_warning += f'\nOnly in {other.name}: {", ".join(unique_2)}'
            print(str_warning)
        # get data: unobserved categories become 0 rather than NaN so that
        # their bars and labels are still drawn
        self_counts = self._data.value_counts().reindex(
            self._categories, fill_value=0)
        other_counts = other._data.value_counts().reindex(
            self._categories, fill_value=0)
        # name the columns explicitly: value_counts() names its result
        # 'count' in pandas >= 2.0, which would produce duplicate columns
        count_data = concat([self_counts, other_counts], axis=1,
                            keys=[self.name, other.name])
        pct_data = count_data / count_data.sum()
        plot_data = count_data if absolute else pct_data
        # plot bars
        axf = axf or AxesFormatter()
        plot_data.plot.bar(ax=axf.axes, color=color, width=width)
        # add labels; each group holds 2 bars of width `width / 2`, centred
        # at i - width / 4 and i + width / 4
        if label_pcts or label_counts:
            for o, ordinal in enumerate(plot_data.columns):
                for i, ix in enumerate(plot_data.index):
                    texts = []
                    if label_pcts:
                        texts.append(
                            format_as_percent(pct_data.loc[ix, ordinal]))
                    if label_counts:
                        texts.append(
                            format_as_integer(count_data.loc[ix, ordinal]))
                    axf.add_text(i - width / 4 + o * width / 2,
                                 plot_data.loc[ix, ordinal],
                                 ' | '.join(texts),
                                 font_size=label_size,
                                 h_align='center',
                                 v_align='bottom')
        # format y-axis
        if absolute:
            axf.y_axis.set_format_integer()
        else:
            axf.y_axis.set_format_percent()

        return axf
Пример #20
0
    def plot_conditional_prob_densities(
            self,
            categorical: Union[DataDistributionMixin, DataCategoriesMixin],
            hdi: float = 0.95,
            width: float = 0.8,
            num_segments: int = 100,
            color: Color = 'k',
            color_min: Optional[Color] = None,
            color_mean: Optional[Color] = None,
            edge_color: Optional[Color] = None,
            axf: Optional[AxesFormatter] = None) -> AxesFormatter:
        """
        Plot conditional probability densities of the data, split by the
        categories of an Ordinal or Nominal distribution.

        For each category, a BetaBinomialConjugate posterior is inferred from
        the rows of this data that fall in that category. Its pdf is drawn as
        a vertical density bar over the requested HDI. Categories with no
        matching rows are skipped.

        :param categorical: Nominal or Ordinal distribution.
        :param hdi: Highest Density Interval width for each distribution.
        :param width: Width of each density bar.
        :param num_segments: Number of segments to plot per density.
        :param color: Color for the densest part of each distribution.
        :param color_min: Color for the sparsest part of each distribution,
                          if different to color.
        :param color_mean: Color for mean data markers.
        :param edge_color: Optional color for the edge of each density bar.
        :param axf: Optional AxesFormatter to plot on.
        :return: The AxesFormatter that was plotted on.
        """
        axf = axf or AxesFormatter()

        cats = categorical.categories
        n_cats = len(cats)
        # running bounds of all HDIs, used to frame the y-axis
        yy_min, yy_max = inf, -inf
        # restrict both series to the index values they have in common
        shared_ix = list(
            set(self._data.index).intersection(categorical.data.index))
        cat_data = categorical.data.loc[shared_ix]
        ratio_data = self._data.loc[shared_ix]
        for c, category in enumerate(cats):
            cat_ratio_data = ratio_data.loc[cat_data == category]
            if len(cat_ratio_data) == 0:
                continue
            # fit distribution and find limits for HDI
            cat_dist = BetaBinomialConjugate.infer_posterior(cat_ratio_data)
            y_min, y_max = cat_dist.hdi(hdi)
            yy_min, yy_max = min(y_min, yy_min), max(y_max, yy_max)
            # plot density at x = 1..n_cats
            axf.add_v_density(x=c + 1,
                              y_to_z=cat_dist.pdf().at(
                                  linspace(y_min, y_max, num_segments + 1)),
                              color=color,
                              color_min=color_min,
                              edge_color=edge_color,
                              width=width)
            # plot descriptive statistics lines
            if color_mean is not None:
                mean = cat_ratio_data.mean()
                axf.add_line(x=[c + 0.55, c + 1.45],
                             y=[mean, mean],
                             color=color_mean)
        # labels
        axf.set_text(title=f'{hdi: .0%} HDIs of $p(' + r'p_{' + self.name +
                     r'}' + f'|{categorical.name})$',
                     x_label=categorical.name,
                     y_label=r'$p_{' + self.name + r'}$')
        # axes
        axf.set_x_lim(0, n_cats + 1)
        if yy_min <= yy_max:
            # pad the combined HDI range by 5% on each side
            yy_range = yy_max - yy_min
            axf.set_y_lim(yy_min - yy_range * 0.05, yy_max + yy_range * 0.05)
        else:
            # no category had any data, so the bounds are still +/-inf and
            # cannot be used as axis limits; fall back to the unit interval
            axf.set_y_lim(0, 1)
        axf.y_ticks.set_locations(linspace(0, 1, 11))
        axf.x_ticks.set_locations(range(1, n_cats + 1)).set_labels(cats)

        return axf
Пример #21
0
    def plot(self,
             x: Optional[Iterable] = None,
             kind: str = 'line',
             color: str = 'C0',
             mean: bool = False,
             median: bool = False,
             mode: bool = False,
             std: bool = False,
             ax: Optional[Axes] = None,
             **kwargs) -> Axes:
        """
        Plot the function.

        :param x: Range of values of x to plot p(x) over. If None, 1001
                  evenly spaced points between the parent's `lower_bound`
                  and `upper_bound` are used.
        :param kind: Kind of plot e.g. 'bar', 'line'.
        :param color: Optional color for the series.
        :param mean: Whether to show marker and label for the mean.
        :param median: Whether to show marker and label for the median.
        :param mode: Whether to show marker and label for the mode.
        :param std: Whether to show marker and label for the standard deviation.
        :param ax: Optional matplotlib axes to plot on.
        :param kwargs: Additional arguments for the matplotlib plot function.
        :raises ValueError: If `x` is None and the parent has no bounds, or
                            the method is not one of pdf, cdf or logpdf.
        :return: The matplotlib axes that were plotted on.
        """
        if x is None:
            if (hasattr(self._parent, 'lower_bound')
                    and hasattr(self._parent, 'upper_bound')):
                x = linspace(self._parent.lower_bound,
                             self._parent.upper_bound, 1001)
            else:
                raise ValueError('Must pass x if distribution has no bounds.')

        data: Series = self.at(x)
        axf = AxesFormatter(axes=ax)
        ax = axf.axes

        if self._method_name not in ('pdf', 'cdf', 'logpdf'):
            raise ValueError('plot not implemented for {}'.format(self._name))
        if 'label' not in kwargs:
            kwargs['label'] = self._parent.label
        data.plot(kind=kind, color=color, ax=ax, **kwargs)

        # stats
        y_min = axf.get_y_min()
        y_max = axf.get_y_max()
        # Put each stat label on the curve that was drawn (pdf, cdf or
        # logpdf) rather than always at pdf height; fall back to the pdf if
        # the underlying distribution lacks the plotted method.
        y_func = getattr(self._distribution, self._method_name, None)
        if y_func is None:
            y_func = self._distribution.pdf

        def _add_stat(line_x, text_x, text: str, line_style: str):
            # vertical line(s) spanning the current y-range + label on curve
            axf.add_v_lines(x=line_x,
                            y_min=y_min,
                            y_max=y_max,
                            line_styles=line_style,
                            colors=color)
            axf.add_text(x=text_x,
                         y=y_func(text_x),
                         text=text,
                         color=color,
                         h_align='center',
                         v_align='bottom')

        # the mean is needed both for its own marker and to centre the std
        if mean or std:
            x_mean = self._distribution.mean()
        if mean:
            _add_stat(x_mean, x_mean, f'mean={x_mean: 0.3f}', '--')
        if median:
            x_median = self._distribution.median()
            _add_stat(x_median, x_median, f'median={x_median: 0.3f}', '-.')
        if mode:
            x_mode = self._parent.mode()
            _add_stat(x_mode, x_mode, f'mode={x_mode: 0.3f}', '-.')
        if std:
            x_std = self._distribution.std()
            _add_stat([x_mean - x_std, x_mean + x_std], x_mean - x_std / 2,
                      f'std={x_std: 0.3f}', ':')

        ax.set_xlabel(self._parent.x_label)

        if self._parent.y_label:
            ax.set_ylabel(self._parent.y_label)
        else:
            # _method_name was validated above, so the lookup always succeeds
            default_y_labels = {
                'pdf': 'P(X = x)',
                'cdf': 'P(X ≤ x)',
                'logpdf': 'log P(X = x)',
            }
            ax.set_ylabel(default_y_labels[self._method_name])

        return ax
Пример #22
0
    def plot_density_bars(
            self,
            color: Union[Color, List[Color]] = 'k',
            color_min: Optional[Union[Color, List[Color]]] = None,
            group_width: float = 0.8,
            stagger: bool = True,
            item_width: float = 0.8,
            hdi: float = 0.95,
            resolution: int = 100,
            z_max: Optional[Union[float, List[float]]] = None,
            log_z: bool = False,
            axf: Optional[AxesFormatter] = None
    ) -> AxesFormatter:
        """
        Plot each row of distributions as a group of density bars.

        Row i is drawn at x = i + 1. Each column's distribution in that row
        is drawn as a vertical density bar over its HDI, and a legend maps
        column names to colors.

        :param color: Color for each column, or all bars.
        :param color_min: Min color for each column, or all bars.
        :param group_width: Width of each column group.
        :param item_width: Width of each item as a proportion of the group
                           width divided by the number of items per group.
        :param stagger: Whether to plot items within a row next to each
                        other.
        :param hdi: Highest Density Interval width for each distribution.
        :param z_max: Optional normalizing constant to divide each bar's height
                      by, either one value for all bars or one per column
                      (a list is cycled if shorter than the number of
                      columns).
        :param resolution: Number of density elements per unit y.
        :param log_z: Whether to take the log of z before plotting.
        :param axf: Optional AxesFormatter instance.
        :return: The AxesFormatter that was plotted on.
        """
        axf = axf or AxesFormatter()
        n_rows = self._data.shape[0]
        n_cols = self._data.shape[1]
        if stagger:
            width_per_item = group_width * item_width / n_cols
            item_centers = linspace(
                -group_width / 2 + width_per_item / 2,
                group_width / 2 - width_per_item / 2,
                n_cols
            )
        else:
            # all items of a row share the same centre
            width_per_item = group_width
            item_centers = [0] * n_cols
        color = loop_variable(color, n_cols)
        color_min = loop_variable(color_min, n_cols)
        # z_max may be given per column: expand it to one value per column
        # instead of passing the whole list to every bar
        if isinstance(z_max, (list, tuple)):
            z_maxes = [z_max[i % len(z_max)] for i in range(n_cols)]
        else:
            z_maxes = [z_max] * n_cols
        # add distributions
        dist: Union[PDF1dMixin, PPFContinuous1dMixin]
        for i_row, (_, row) in enumerate(self._data.iterrows()):
            for i_col, (_, dist) in enumerate(row.items()):
                y_min, y_max = dist.hdi(hdi)
                # at least one segment, even for a very narrow HDI
                n_bars = max(1, round(resolution * (y_max - y_min)))
                y_to_z = dist.pdf().at(linspace(y_min, y_max, n_bars + 1))
                if log_z:
                    y_to_z = y_to_z.map(log)
                axf.add_v_density(
                    x=i_row + 1 + item_centers[i_col],
                    y_to_z=y_to_z,
                    color=color[i_col], color_min=color_min[i_col],
                    width=width_per_item,
                    z_max=z_maxes[i_col],
                    h_align='center'
                )
        # axes
        axf.set_x_lim(0, n_rows + 1)
        axf.x_ticks.set_locations(range(1, n_rows + 1))
        axf.x_ticks.set_labels(self._data.index)
        axf.y_ticks.set_locations(arange(0, 1.1, 0.1))
        axf.set_y_lim(-0.05, 1.05)
        # legend
        patches = [
            Patch(color=color[i_col], label=self._data.columns[i_col])
            for i_col in range(n_cols)
        ]
        axf.axes.legend(handles=patches)
        return axf
Пример #23
0
    def plot_densities(data: Optional[DataFrame] = None,
                       labels: Optional[Union[Series, str]] = None,
                       distributions: Optional[Union[Series, str]] = None,
                       color: Color = 'k',
                       color_min: Optional[Color] = None,
                       width: float = 0.8,
                       num_strips: int = 100,
                       ax: Optional[Axes] = None) -> Axes:
        """
        Plot a density plot (continuous boxplot) of distribution pdfs.

        Each distribution's pdf is evaluated between its `lower_bound` and
        `upper_bound` and drawn as a vertical density bar at x = 1..n. All
        bars share a common z normalisation (the largest pdf value).

        :param data: Optional DataFrame containing labels and distributions.
        :param labels: Series of labels or name of column.
        :param distributions: Series of distributions or name of column.
        :param color: The color of the density bar.
        :param color_min: Optional 2nd color to fade out to.
        :param width: The bar width.
        :param num_strips: Number of strips for each density bar.
        :param ax: Optional matplotlib Axes instance.
        :raises TypeError: If the combination of `data`, `labels` and
                           `distributions` types is invalid.
        :return: The matplotlib Axes that were plotted on.
        """
        # check arguments
        ax = ax or new_axes()
        if data is None:
            if not (isinstance(labels, Series)
                    and isinstance(distributions, Series)):
                raise TypeError('If data is not given, '
                                'labels and distributions must both be Series')
        else:
            if not isinstance(data, DataFrame):
                raise TypeError('data must be a DataFrame')
            if not (isinstance(labels, str)
                    and isinstance(distributions, str)):
                raise TypeError('If data is given, '
                                'labels and distributions must both be str')
            labels: Series = data[labels]
            distributions: Series = data[distributions]
        # plot densities
        axf = AxesFormatter(ax)
        distribution: RVContinuous1dMixin
        y_to_z: Dict[Any, Series] = {}
        max_z = 0
        # start from +/- infinity so bounds of any magnitude are captured
        # (finite sentinels such as +/-1e6 clip larger bounds)
        min_y = float('inf')
        max_y = float('-inf')
        for x_label, distribution in zip(labels, distributions):
            y_dist = linspace(distribution.lower_bound,
                              distribution.upper_bound, num_strips + 1)
            min_y = min(min_y, y_dist[0])
            max_y = max(max_y, y_dist[-1])
            y_to_z[x_label] = Series(index=y_dist,
                                     data=distribution.pdf().at(y_dist))
            max_z = max(max_z, y_to_z[x_label].max())
        label_ix = {label: i + 1 for i, label in enumerate(list(labels))}
        for label, z_values in y_to_z.items():
            axf.add_v_density(x=label_ix[label],
                              y_to_z=z_values,
                              color=color,
                              color_min=color_min,
                              width=width,
                              z_max=max_z)
        num_labels = len(labels)
        axf.x_axis.axis.set_ticks(range(1, num_labels + 1))
        axf.x_axis.axis.set_ticklabels(labels)
        axf.set_x_lim(0, num_labels + 1)
        if min_y <= max_y:
            # only frame the y-axis when at least one distribution was drawn
            axf.set_y_lim(floor(min_y), ceil(max_y))
        axf.set_text(y_label='x')
        return axf.axes
Пример #24
0
    def plot(self,
             k: Optional[Iterable[int]] = None,
             color: str = 'C0',
             kind: str = 'bar',
             mean: bool = False,
             median: bool = False,
             std: bool = False,
             ax: Optional[Axes] = None,
             **kwargs) -> Axes:
        """
        Plot the function.

        :param k: Range of values of k to plot p(k) over. If None, every
                  integer from the parent's `lower_bound` to `upper_bound`
                  (inclusive) is used.
        :param color: Optional color for the series.
        :param kind: Kind of plot e.g. 'bar', 'line' (pmf only; the cdf is
                     always drawn as a post-step line).
        :param mean: Whether to show marker and label for the mean.
        :param median: Whether to show marker and label for the median.
        :param std: Whether to show marker and label for the standard deviation.
        :param ax: Optional matplotlib axes to plot on.
        :param kwargs: Additional arguments for the matplotlib plot function.
                       A truthy `vlines` kwarg is consumed here and draws a
                       vertical line from 0 up to each plotted value.
        :raises ValueError: If `k` is None and the parent has no bounds, or
                            the method is not one of pmf or cdf.
        :return: The matplotlib axes that were plotted on.
        """
        if k is None:
            if (hasattr(self._parent, 'lower_bound')
                    and hasattr(self._parent, 'upper_bound')):
                k = range(self._parent.lower_bound,
                          self._parent.upper_bound + 1)
            else:
                raise ValueError('Must pass k if distribution has no bounds.')

        data: Series = self.at(k)
        axf = AxesFormatter(axes=ax)
        ax = axf.axes

        # special kwargs
        vlines = kwargs.pop('vlines', None)
        if 'label' not in kwargs:
            kwargs['label'] = self._parent.label

        if self._method_name == 'pmf':
            data.plot(kind=kind, color=color, ax=ax, **kwargs)
        elif self._method_name == 'cdf':
            data.plot(kind='line',
                      color=color,
                      drawstyle='steps-post',
                      ax=ax,
                      **kwargs)
        else:
            raise ValueError('plot not implemented for {}'.format(self._name))
        if vlines:
            ax.vlines(x=k, ymin=0, ymax=data.values, color=color)

        y_min = axf.get_y_min()
        y_max = axf.get_y_max()
        # Put each stat label on the curve that was drawn (pmf or cdf)
        # rather than always at pmf height; fall back to the pmf if the
        # underlying distribution lacks the plotted method.
        y_func = getattr(self._distribution, self._method_name, None)
        if y_func is None:
            y_func = self._distribution.pmf

        def _add_stat(line_x, text_x, text: str, line_style: str):
            # vertical line(s) spanning the current y-range + label on curve
            axf.add_v_lines(x=line_x,
                            y_min=y_min,
                            y_max=y_max,
                            line_styles=line_style,
                            colors=color)
            axf.add_text(x=text_x,
                         y=y_func(text_x),
                         text=text,
                         color=color,
                         h_align='center',
                         v_align='bottom')

        # the mean is needed both for its own marker and to centre the std
        if mean or std:
            x_mean = self._distribution.mean()
        if mean:
            _add_stat(x_mean, x_mean, f'mean={x_mean: 0.3f}', '--')
        if median:
            x_median = self._distribution.median()
            _add_stat(x_median, x_median, f'median={x_median: 0.3f}', '-.')
        if std:
            x_std = self._distribution.std()
            _add_stat([x_mean - x_std, x_mean + x_std], x_mean - x_std / 2,
                      f'std={x_std: 0.3f}', ':')

        ax.set_xlabel(self._parent.x_label)

        if self._parent.y_label:
            ax.set_ylabel(self._parent.y_label)
        else:
            # _method_name was validated above, so the lookup always succeeds
            default_y_labels = {'pmf': 'P(K = k)', 'cdf': 'P(K ≤ k)'}
            ax.set_ylabel(default_y_labels[self._method_name])

        return ax
Пример #25
0
wins ∝ P($|R=win)
losses ∝ P($|R=loss)
"""


# Win and loss counts as scaled Poisson pmfs over k = 1..50. The scale
# factors make the totals roughly 300 wins and 700 losses, since the pmf over
# 1..50 sums to almost 1 for these lambdas.
wins: Series = \
    300 * Poisson(lambda_=8).pmf().at(range(1, 51)).rename('win')
# relabel the 50 bins as 100,000 .. 5,000,000 in steps of 100,000
wins.index = range(100_000, 5_000_001, 100_000)
# round to whole counts
wins = wins.round(0).astype(int)
losses: Series = \
    700 * Poisson(lambda_=10).pmf().at(range(1, 51)).rename('loss')
losses.index = range(100_000, 5_000_001, 100_000)
losses = losses.round(0).astype(int)

# one column per outcome (losses first), drawn as a grouped bar chart
data = concat([wins, losses], axis=1)[['loss', 'win']]
axf = AxesFormatter()
data.plot.bar(ax=axf.axes)
axf.show()


def plot_probs(probs: DataFrame, weight):
    """
    Box-plot posterior probabilities per likelihood, split by prior, and
    show the figure.

    :param probs: DataFrame of posterior probabilities whose column index
                  has 'likelihood' and 'prior' levels. Both levels are
                  stacked into rows here (presumably a MultiIndex; TODO
                  confirm against callers).
    :param weight: Value whose string form is used as the plot title.
    """
    axf = AxesFormatter()
    data = probs.stack(
        level=['likelihood', 'prior']).rename('posterior').reset_index()
    # draw explicitly on this formatter's axes rather than relying on
    # seaborn picking up whichever axes is current
    boxplot(data=data, x='likelihood', y='posterior', hue='prior',
            ax=axf.axes)
    axf.rotate_x_tick_labels(90)
    axf.set_y_lim(0, 1.05)
    axf.set_axis_below().grid()
    axf.set_text(title=str(weight))
    axf.show()