Пример #1
0
def plot_speeds(
    body_speed,
    start: int = 0,
    end: int = -1,
    is_moving: np.ndarray = None,
    ax: plt.axis = None,
    show: bool = True,
    **other_speeds,
):
    """
        Plot body speed (and any other speed traces) for debugging.

        Every trace, including ``is_moving``, is restricted to the same
        ``[start:end]`` slice so that all curves share one frame axis. As with
        any Python slice, the default ``end=-1`` leaves out the last sample.

        :param body_speed: indexable speed trace, drawn in salmon as "body".
        :param start: first index of the plotted window (inclusive).
        :param end: slice stop of the plotted window (exclusive).
        :param is_moving: optional trace drawn as a thick black line; assumed
            to be indexed by frame like ``body_speed`` (TODO confirm).
        :param ax: axis to draw on; a new 14x8 figure is created when None.
        :param show: call ``plt.show()`` at the end when True.
        :param other_speeds: extra speed traces; each keyword is used as the
            legend label of its curve.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(14, 8))

    window = slice(start, end)

    ax.plot(body_speed[window], label="body", color="salmon")
    for name, speed in other_speeds.items():
        ax.plot(speed[window], label=name)

    if is_moving is not None:
        # Slice with the same window as the speed traces; drawing the full
        # trace misaligned it with the speeds whenever start != 0.
        ax.plot(is_moving[window], lw=4, color="k")

    ax.legend()
    ax.set(xlabel="time (frames)", ylabel="speed (cm/s)")
    if show:
        plt.show()
Пример #2
0
    def mpl_plot(self,
                 ax: plt.axis = None,
                 temp_unit: str = "GK",
                 **kwargs) -> plt.axis:
        """
        Plot ``self.rate`` on a log-log axis, then delegate to the parent class.

        The rate is evaluated on 1000 log-spaced grid points between 1e-2 and
        1e1 (presumably in GK, the unit ``self.rate`` expects — TODO confirm).

        :param ax: axis to draw on; the current axis is used when None.
        :param temp_unit: "GK" plots against the raw grid; "KeV" plots against
            ``Temperature(t).kev``. Any other value draws no rate curve.
        :param kwargs: forwarded to ``ax.loglog``.
        :return: the axis returned by the parent class's ``mpl_plot``.
        """
        ax = ax or plt.gca()

        t = np.logspace(-2, 1, 1000)
        # Compare by value: ``is`` tests object identity, which only works by
        # accident of string interning (and is a SyntaxWarning on Python 3.8+).
        if temp_unit == "GK":
            x_values = t
            label = "Reaclib-" + self.label + " " + self.__str__()
        elif temp_unit == "KeV":
            x_values = Temperature(t).kev
            label = self.label + " " + self.__str__()
        else:
            x_values = None

        if x_values is not None:
            ax.loglog(
                x_values,
                self.rate(t),
                color=self.color,
                label=label,
                **kwargs,
            )

        ax.legend()

        ax = super().mpl_plot(ax=ax, temp_unit=temp_unit)

        return ax
Пример #3
0
    def plot_on(
        self,
        ax: plt.axis,
        draw_legend: bool = True,
        legend_inside: bool = True,
        legend_kwargs: dict = None,
        yaxis_scale=1.3,
        hide_labels: bool = False,
    ) -> plt.axis:
        """
        Draw every component in ``self.components`` on ``ax``.

        Components with style 'point' are drawn as error bars. Components with
        style 'box' are drawn as semi-transparent bars of width 2*x_error,
        spanning y - y_error to y + y_error.

        :param ax: axis to draw on.
        :param draw_legend: draw a legend if True.
        :param legend_inside: place the legend inside the axis and stretch the
            upper y-limit by ``yaxis_scale``; otherwise anchor it at (1, 1).
        :param legend_kwargs: extra keyword arguments for ``ax.legend``. They
            override the defaults ``frameon=False`` and, for an outside legend,
            ``bbox_to_anchor=(1, 1)``. None means no extra arguments.
        :param yaxis_scale: factor applied to the upper y-limit when the legend
            is drawn inside.
        :param hide_labels: skip setting the x/y axis labels if True.
        :return: the same axis.
        :raises NotImplementedError: if a component's style is neither 'point'
            nor 'box'.
        """
        # Build a fresh dict per call (no shared mutable default); caller values
        # override the defaults instead of raising "multiple values" TypeErrors.
        legend_options = {"frameon": False, **(legend_kwargs or {})}

        for component in self.components:
            if component.style == 'point':
                ax.errorbar(
                    component.data.x_values,
                    component.data.y_values,
                    yerr=component.data.y_errors,
                    xerr=component.data.x_errors,
                    label=component.label,
                    color=component.color,
                    ls=component.ls,
                    marker=component.marker,
                )
            elif component.style == 'box':
                ax.bar(component.data.x_values,
                       2 * component.data.y_errors,
                       width=2 * component.data.x_errors,
                       bottom=component.data.y_values -
                       component.data.y_errors,
                       label=component.label,
                       color=component.color,
                       alpha=0.5)
            else:
                raise NotImplementedError(
                    "Options: point|box. If you require a new kind of plot, report a feature request"
                )

        if not hide_labels:
            ax.set_xlabel(self.variable.x_label, plot_style.xlabel_pos)
            ax.set_ylabel(self.variable.y_label, plot_style.ylabel_pos)

        if draw_legend:
            if legend_inside:
                ax.legend(**legend_options)
                ylims = ax.get_ylim()
                ax.set_ylim(ylims[0], yaxis_scale * ylims[1])
            else:
                ax.legend(**{"bbox_to_anchor": (1, 1), **legend_options})

        return ax
Пример #4
0
    def _sub_plot_multiple(self, axis: plt.axis, columns: tuple):
        '''
        Create a sub-plot for the given columns.

        Each column of ``self.m_data_df`` is drawn as a thin solid line on
        ``axis``. The sub-plot is titled with the comma-separated column names,
        the x-axis is labelled 'Rollouts', and a legend is added.

        Parameters
        ----------
        axis : plt.axis
            figure axis to draw on
        columns : tuple
            column names of ``self.m_data_df`` to plot
        '''
        for col in columns:
            axis.plot(self.m_data_df[col], label=col, linestyle='-', linewidth=0.1)
        # join() avoids the trailing ", " left by repeated concatenation
        axis.set_title(', '.join(map(str, columns)))
        axis.set_xlabel('Rollouts')
        axis.legend()
Пример #5
0
    def mpl_plot(self, ax: plt.axis = None, temp_unit: str = "GK", **kwargs):
        """
        Plot the tabulated rates ``self.rr`` with error bars ``self.err``.

        The stored temperatures (in ``self.temp_unit``) are converted with
        ``Temperature`` to the requested unit. The y-axis is set to log scale,
        a legend is drawn, and the parent class's ``mpl_plot`` is called.

        :param ax: axis to draw on; the current axis is used when None.
        :param temp_unit: "GK" or "KeV". Any other value draws no data points.
        :param kwargs: forwarded to ``ax.errorbar``.
        :return: the axis returned by the parent class's ``mpl_plot``.
        """
        ax = ax or plt.gca()

        # Compare by value: ``is`` tests object identity, which is unreliable
        # for strings (and a SyntaxWarning on Python 3.8+).
        if temp_unit == "GK":
            x_values = Temperature(np.array(self.temperature),
                                   unit=self.temp_unit).gk
        elif temp_unit == "KeV":
            x_values = Temperature(np.array(self.temperature),
                                   unit=self.temp_unit).kev
        else:
            x_values = None

        if x_values is not None:
            ax.errorbar(
                x_values,
                self.rr,
                yerr=self.err,
                color=self.colour,
                label="{0} {1}".format(self.label, self.__str__()),
                **kwargs,
            )
            ax.scatter(
                x_values,
                self.rr,
                color=self.colour,
            )

        ax.set_yscale("log")

        ax.legend()

        ax = super().mpl_plot(ax, temp_unit=temp_unit)

        return ax
Пример #6
0
def add_cut_to_axis(ax: plt.axis,
                    cut_left: Optional[float] = None,
                    cut_right: Optional[float] = None,
                    cut_window: Optional[Tuple[float, float]] = None,
                    keep_window: Optional[Tuple[float, float]] = None,
                    color: str = 'white'):
    """
    Adds a "cut" to a given axis. The cut is shown as a shaded area with the
    color given in the parameter color, bounded by dashed black lines.

    Only one kind of cut is drawn. The arguments are checked in the order
    cut_left, cut_right, cut_window, keep_window, and the first one that is not
    None is used. The shaded areas extend to the x-limits the axis had when this
    function was called. A legend is always drawn outside the axis at (1, 1).

    :param ax: Axis to plot on.
    :param cut_left: Upper x value of the cut. If set, the area with
    x < cut_left is indicated to be cut away. Default is None
    :param cut_right: Lower x value of the cut. If set, the area with
    x > cut_right is indicated to be cut away. Default is None
    :param cut_window: (low, high) pair. If set, the area low < x < high is
    indicated to be cut away. Default is None
    :param keep_window: (low, high) pair. If set, everything outside
    low < x < high is indicated to be cut away. Default is None
    :param color: Color of the overlay of the area which is indicated to be
    cut away. Default is 'white'.
    """
    x_lim_low, x_lim_high = ax.get_xlim()
    span_style = dict(facecolor=color, alpha=0.7)
    line_style = dict(color='black', linestyle='dashed', lw=1.5)

    if cut_left is not None:
        ax.axvspan(x_lim_low, cut_left, **span_style)
        ax.axvline(cut_left, label='Cut', **line_style)
    elif cut_right is not None:
        ax.axvspan(cut_right, x_lim_high, **span_style)
        ax.axvline(cut_right, label='Cut', **line_style)
    elif cut_window is not None:
        ax.axvspan(cut_window[0], cut_window[1], **span_style)
        ax.axvline(cut_window[0], label='Cut', **line_style)
        ax.axvline(cut_window[1], **line_style)
    elif keep_window is not None:
        ax.axvspan(x_lim_low, keep_window[0], **span_style)
        ax.axvline(keep_window[0], label='Cut', **line_style)
        ax.axvspan(keep_window[1], x_lim_high, **span_style)
        # Only the first boundary line is labelled, so the legend shows a
        # single "Cut" entry (as in the cut_window branch) instead of two.
        ax.axvline(keep_window[1], **line_style)

    ax.legend(frameon=False, bbox_to_anchor=(1, 1))
Пример #7
0
    def plot_on(self,
                ax: plt.axis,
                ylabel="Events",
                draw_legend=True,
                legend_inside=True,
                hide_labels: bool = False,
                yaxis_scale: float = 1.4):
        """
        Draw the 'stacked' MC components as one stacked, filled histogram.

        The bin edges, mids and width from ``_get_bin_edges`` are also stored on
        the instance.

        :param ax: axis to draw on.
        :param ylabel: event/candidate word passed to ``_get_y_label``.
        :param draw_legend: draw a legend if True.
        :param legend_inside: place the legend inside the axis and stretch the
            upper y-limit by ``yaxis_scale``; otherwise anchor it at (1, 1).
        :param hide_labels: skip setting the x/y axis labels if True.
        :param yaxis_scale: factor applied to the upper y-limit when the legend
            is drawn inside (defaults to 1.4, the previously hard-coded value).
        :return: the same axis.
        """
        bin_edges, bin_mids, bin_width = self._get_bin_edges()

        self._bin_edges = bin_edges
        self._bin_mids = bin_mids
        self._bin_width = bin_width

        stacked_components = self._mc_components['stacked']

        ax.hist(
            x=[comp.data for comp in stacked_components],
            bins=bin_edges,
            weights=[comp.weights for comp in stacked_components],
            stacked=True,
            edgecolor="black",
            lw=0.3,
            color=[comp.color for comp in stacked_components],
            label=[comp.label for comp in stacked_components],
            histtype='stepfilled')

        if not hide_labels:
            ax.set_xlabel(self._variable.x_label, plot_style.xlabel_pos)
            y_label = self._get_y_label(False, bin_width, ylabel)
            ax.set_ylabel(y_label, plot_style.ylabel_pos)
        if draw_legend:
            if legend_inside:
                ax.legend(frameon=False)
                ylims = ax.get_ylim()
                ax.set_ylim(ylims[0], yaxis_scale * ylims[1])

            else:
                ax.legend(frameon=False, bbox_to_anchor=(1, 1))

        return ax
Пример #8
0
    def plot_on(
        self,
        ax1: plt.axis,
        ax2,
        style="stacked",
        ylabel="Events",
        sum_color=plot_style.KITColors.kit_purple,
        draw_legend: bool = True,
        legend_inside: bool = True,
    ):
        """
        Draw MC vs. data on ``ax1`` and the (Data - MC) / Data ratio on ``ax2``.

        The bin edges, mids and width from ``_get_bin_edges`` are also stored on
        the instance.

        :param ax1: main axis. It receives the MC histogram/band and the data
            points, drawn with sqrt(N) error bars.
        :param ax2: ratio axis; its y-range is fixed to (-1, 1).
        :param style: "stacked" draws the MC components stacked plus a hatched
            MC stat. uncertainty band. "summed" draws only the summed MC band in
            ``sum_color``. The comparison is case-insensitive; any other value
            draws no MC.
        :param ylabel: event/candidate word passed to ``_get_y_label``.
        :param sum_color: fill colour of the MC band for style "summed".
        :param draw_legend: draw a legend on ax1 if True.
        :param legend_inside: place the legend inside ax1 and stretch its upper
            y-limit by 1.4; otherwise anchor it at (1, 1).
        """
        bin_edges, bin_mids, bin_width = self._get_bin_edges()

        self._bin_edges = bin_edges
        self._bin_mids = bin_mids
        self._bin_width = bin_width

        mc_components = self._mc_components["MC"]

        # Per-bin sums of the MC weights and of the squared weights.
        sum_w = np.sum(np.array([
            binned_statistic(comp.data,
                             comp.weights,
                             statistic="sum",
                             bins=bin_edges)[0]
            for comp in mc_components
        ]), axis=0)
        sum_w2 = np.sum(np.array([
            binned_statistic(comp.data,
                             comp.weights**2,
                             statistic="sum",
                             bins=bin_edges)[0]
            for comp in mc_components
        ]), axis=0)
        # sqrt(sum of w^2) is used as the per-bin MC statistical uncertainty.
        mc_unc = np.sqrt(sum_w2)

        hdata, _ = np.histogram(self._data_component.data, bins=bin_edges)

        style = style.lower()
        if style == "stacked":
            ax1.hist(
                x=[comp.data for comp in mc_components],
                bins=bin_edges,
                weights=[comp.weights for comp in mc_components],
                stacked=True,
                edgecolor="black",
                lw=0.3,
                color=[comp.color for comp in mc_components],
                label=[comp.label for comp in mc_components],
                histtype='stepfilled')

            # Use the local bin width computed above; the previous
            # ``self.bin_width`` relied on an attribute this method never sets
            # (it stores ``self._bin_width``).
            ax1.bar(x=bin_mids,
                    height=2 * mc_unc,
                    width=bin_width,
                    bottom=sum_w - mc_unc,
                    color="black",
                    hatch="///////",
                    fill=False,
                    lw=0,
                    label="MC stat. unc.")
        elif style == "summed":
            ax1.bar(x=bin_mids,
                    height=2 * mc_unc,
                    width=bin_width,
                    bottom=sum_w - mc_unc,
                    color=sum_color,
                    lw=0,
                    label="MC")

        ax1.errorbar(x=bin_mids,
                     y=hdata,
                     yerr=np.sqrt(hdata),
                     ls="",
                     marker=".",
                     color="black",
                     label=self._data_component.label)

        y_label = self._get_y_label(False, bin_width, evts_or_cand=ylabel)
        ax1.set_ylabel(y_label, plot_style.ylabel_pos)

        if draw_legend:
            if legend_inside:
                ax1.legend(frameon=False)
                ylims = ax1.get_ylim()
                ax1.set_ylim(ylims[0], 1.4 * ylims[1])
            else:
                ax1.legend(frameon=False, bbox_to_anchor=(1, 1))

        ax2.set_ylabel(r"$\frac{\mathrm{Data - MC}}{\mathrm{Data}}$")
        ax2.set_xlabel(self._variable.x_label, plot_style.xlabel_pos)
        ax2.set_ylim((-1, 1))

        # The zero reference line is drawn in every case.
        ax2.axhline(y=0, color=plot_style.KITColors.dark_grey, alpha=0.8)
        try:
            uhdata = unp.uarray(hdata, np.sqrt(hdata))
            uhmc = unp.uarray(sum_w, mc_unc)
            ratio = (uhdata - uhmc) / uhdata
        except ZeroDivisionError:
            # Presumably a bin without data events (hdata == 0); the ratio is
            # then undefined and only the reference line is shown.
            ratio = None

        if ratio is not None:
            ax2.errorbar(bin_mids,
                         unp.nominal_values(ratio),
                         yerr=unp.std_devs(ratio),
                         ls="",
                         marker=".",
                         color=plot_style.KITColors.kit_black)

        # Adjust the figure ax1 lives on, not whichever figure is current.
        ax1.figure.subplots_adjust(hspace=0.08)
Пример #9
0
    def plot_on(self,
                ax: plt.axis,
                draw_legend: bool = True,
                legend_inside: bool = True,
                yaxis_scale=1.3,
                normed: bool = False,
                ylabel="Events",
                hide_labels: bool = False) -> plt.axis:
        """
        Draw each 'single' MC component as its own histogram on ``ax``.

        Components with histtype 'stepfilled' get a black outline and
        alpha 0.6. All others are outlined in their own colour at full opacity.
        The bin edges, mids and width from ``_get_bin_edges`` are stored on the
        instance.

        :param ax: matplotlib axis the histograms are drawn on.
        :param draw_legend: draw a legend on the axis if True.
        :param legend_inside: put the legend inside the axis and stretch the
            upper y-limit by ``yaxis_scale``; otherwise anchor it at (1, 1).
        :param yaxis_scale: factor applied to the upper y-limit for an inside
            legend.
        :param normed: forwarded to ``ax.hist`` as ``density``.
        :param ylabel: event/candidate word handed to ``_get_y_label``.
        :param hide_labels: skip setting the x/y axis labels if True.
        :return: the same axis, with the histograms drawn on it.
        """
        edges, mids, width = self._get_bin_edges()
        self._bin_edges, self._bin_mids, self._bin_width = edges, mids, width

        for comp in self._mc_components['single']:
            filled = comp.histtype == 'stepfilled'
            ax.hist(x=comp.data,
                    bins=edges,
                    density=normed,
                    weights=comp.weights,
                    histtype=comp.histtype,
                    label=comp.label,
                    edgecolor='black' if filled else comp.color,
                    alpha=0.6 if filled else 1.0,
                    lw=1.5,
                    ls=comp.ls,
                    color=comp.color)

        if not hide_labels:
            ax.set_xlabel(self._variable.x_label, plot_style.xlabel_pos)
            ax.set_ylabel(
                self._get_y_label(normed=normed,
                                  bin_width=width,
                                  evts_or_cand=ylabel),
                plot_style.ylabel_pos)

        if not draw_legend:
            return ax

        if not legend_inside:
            ax.legend(frameon=False, bbox_to_anchor=(1, 1))
            return ax

        ax.legend(frameon=False)
        low, high = ax.get_ylim()
        ax.set_ylim(low, yaxis_scale * high)
        return ax
Пример #10
0
def word_group_visualization(
    transformed_word_embeddings: np.ndarray,
    words: np.ndarray,
    word_groups: dict,
    xlabel: str,
    ylabel: str,
    emphasis_words: list = None,
    alpha: float = 1,
    non_group_words_color: str = "#ccc",
    scatter_set_rasterized: bool = False,
    rasterization_threshold: int = 1000,
    ax: plt.axis = None,
    show_plot: bool = True,
) -> None:
    """
    Visualizes one or more word groups by plotting its word embeddings in 2D.

    Only the first two columns of the embeddings are used as x/y coordinates.

    Parameters
    ----------
    transformed_word_embeddings : np.ndarray
        Transformed word embeddings, one row per entry of `words`.
    words : np.ndarray
        Numpy array containing all words from vocabulary.
    word_groups : dict
        Dictionary containing word groups to visualize. Each value is read for
        the keys "words", "color", "label" and, optionally, "boundaries" (a
        dict with any of "xmin", "xmax", "ymin", "ymax"). Missing bounds default
        to the extent of the group's embeddings, and group words outside the
        bounds are not plotted as group words. The caller's dicts are not
        modified.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    emphasis_words : list, optional
        List representing words to emphasize in the visualization (defaults to None).
        Entries can be either strings (words) or tuples, consisting of the word, x-offset
        and y-offset.
    alpha : float
        Scatter plot alpha value (defaults to 1).
    non_group_words_color : str
        Color for words outside groups (defaults to #ccc).
    scatter_set_rasterized : bool
        Whether or not to enable rasterization on scatter plotting (defaults to False).
    rasterization_threshold : int
        The least number of data points to enable rasterization, given that
        `scatter_set_rasterized` is set to True (defaults to 1000).
    ax : plt.axis
        Axis (defaults to None, in which case a new 12x7 figure is created).
    show_plot : bool
        Whether or not to call plt.show() (defaults to True).

    Raises
    ------
    IndexError
        If an emphasis word is not part of `words`.
    """
    # Word -> index of its first occurrence in `words` (same result as
    # np.where(words == word)[0][0]), giving O(1) lookups instead of a full
    # vocabulary scan per word.
    word_to_index = {}
    for idx, word in enumerate(words):
        word_to_index.setdefault(word, idx)

    def _indices_of(group_words):
        # Integer index array (also valid when empty) for vocabulary words.
        return np.array([word_to_index[word] for word in group_words],
                        dtype=int)

    # Filter and restrict words in word groups
    word_group_words_restricted = {}
    for group_key, group_data in word_groups.items():
        group_words = np.array(
            [word for word in group_data["words"] if word in word_to_index])
        if len(group_words) == 0:
            # None of the group's words are in the vocabulary: nothing to plot
            # (min()/max() below would fail on an empty array).
            word_group_words_restricted[group_key] = group_words
            continue

        group_word_embeddings = transformed_word_embeddings[
            _indices_of(group_words)]
        xs = group_word_embeddings[:, 0]
        ys = group_word_embeddings[:, 1]

        # Copy so that filling in default bounds never writes into the
        # caller's word_groups (which would leak into later calls).
        boundaries = dict(group_data.get("boundaries") or {})
        default_bounds = {
            "xmin": xs.min(),
            "xmax": xs.max(),
            "ymin": ys.min(),
            "ymax": ys.max(),
        }
        for bound_key, default_value in default_bounds.items():
            if boundaries.get(bound_key) is None:
                boundaries[bound_key] = default_value

        inside = ((boundaries["xmin"] <= xs) & (xs <= boundaries["xmax"])
                  & (boundaries["ymin"] <= ys) & (ys <= boundaries["ymax"]))
        word_group_words_restricted[group_key] = group_words[inside]

    # Find words not in any group. The former nested comprehension kept a word
    # whenever it was missing from *some* group, so with several groups it
    # re-plotted group words as non-group words and produced duplicates.
    words_in_groups = set()
    for group_words in word_group_words_restricted.values():
        words_in_groups.update(group_words)
    non_group_indices = [
        i for i, word in enumerate(words) if word not in words_in_groups
    ]

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 7))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Plot non-group words
    non_group_embeddings = transformed_word_embeddings[non_group_indices]
    non_grp_scatter_handle = ax.scatter(
        x=non_group_embeddings[:, 0],
        y=non_group_embeddings[:, 1],
        s=10,
        alpha=alpha,
        c=non_group_words_color,
    )
    if (scatter_set_rasterized
            and len(non_group_indices) >= rasterization_threshold):
        non_grp_scatter_handle.set_rasterized(True)

    # Plot group words
    for group_key, group_words in word_group_words_restricted.items():
        group_word_embeddings = transformed_word_embeddings[
            _indices_of(group_words)]

        grp_scatter_handle = ax.scatter(
            x=group_word_embeddings[:, 0],
            y=group_word_embeddings[:, 1],
            s=15,
            alpha=alpha,
            c=word_groups[group_key]["color"],
            label=word_groups[group_key]["label"],
        )
        if (scatter_set_rasterized
                and len(group_word_embeddings) >= rasterization_threshold):
            grp_scatter_handle.set_rasterized(True)

    # Visualize emphasized words
    if emphasis_words is not None:
        emphasis_entries = [(entry, 0, 0) if isinstance(entry, str) else entry
                            for entry in emphasis_words]
        for emphasis_word, x_offset, y_offset in emphasis_entries:
            # Colour like the first group listing the word, else non-group.
            word_color = next(
                (group_data["color"] for group_data in word_groups.values()
                 if emphasis_word in group_data["words"]),
                non_group_words_color,
            )

            if emphasis_word not in word_to_index:
                raise IndexError(
                    f"Emphasis word {emphasis_word!r} is not in the vocabulary")
            word_idx = word_to_index[emphasis_word]

            emphasis_scatter_handle = ax.scatter(
                x=transformed_word_embeddings[word_idx, 0],
                y=transformed_word_embeddings[word_idx, 1],
                s=40,
                alpha=alpha,
                c=word_color,
            )
            if (scatter_set_rasterized
                    and len(emphasis_entries) >= rasterization_threshold):
                emphasis_scatter_handle.set_rasterized(True)

            # Annotate emphasis word with a text box
            offsetbox = TextArea(emphasis_word)
            ab = AnnotationBbox(
                offsetbox,
                tuple(transformed_word_embeddings[word_idx]),
                xybox=(x_offset, 40 + y_offset),
                xycoords="data",
                boxcoords="offset points",
                arrowprops=dict(arrowstyle="->", color="black", linewidth=2),
            )
            ax.add_artist(ab)

    ax.legend()
    if show_plot:
        plt.show()
Пример #11
0
def visualize_word_cluster_groups(
    transformed_word_embeddings: np.ndarray,
    words: np.ndarray,
    word_groups: dict,
    visualize_non_group_words: bool,
    xlabel: str,
    ylabel: str,
    non_group_words_color: str = "#ccc",
    ax: plt.axis = None,
    show_plot: bool = True,
    alpha: float = 1,
    interactive: bool = False,
) -> None:
    """
    Visualizes word cluster groups.

    Parameters
    ----------
    transformed_word_embeddings : np.ndarray
        Transformed word embeddings, one row per entry of `words`; the first
        two columns are used as x/y coordinates.
    words : np.ndarray
        Numpy array containing all words from vocabulary.
    word_groups : dict
        Dictionary containing word groups to visualize. Each value is read for
        the keys "words" and "color".
    visualize_non_group_words : bool
        Whether or not to visualize words outside word groups
    xlabel : str
        X-axis label
    ylabel : str
        Y-axis label
    non_group_words_color : str
        Color for words outside groups (defaults to #ccc)
    ax : plt.axis
        Matplotlib axis (defaults to None); ignored when `interactive` is True
    show_plot : bool
        Whether or not to call plt.show() (defaults to True); the interactive
        Plotly figure is always shown
    alpha : float
        Scatter plot alpha value (defaults to 1)
    interactive : bool
        Whether or not to make the visualization interactive
        using Plotly (defaults to False).
    """
    if ax is None and not interactive:
        _, ax = plt.subplots(figsize=(12, 7))
    if interactive:
        fig = go.Figure(
            layout=dict(xaxis=dict(title=xlabel), yaxis=dict(title=ylabel)))

    # Matplotlib scatter handles and their legend texts, collected in drawing
    # order. Passing both to ax.legend keeps every label attached to its own
    # scatter. The old fixed label list always began with "Non group words",
    # which shifted all group labels when non-group words were not drawn.
    legend_handles = []
    legend_labels = []

    if visualize_non_group_words:

        # Create boolean mask for words outside groups (set for O(1) lookups)
        words_in_groups = set()
        for word_group in word_groups.values():
            words_in_groups.update(word_group["words"])
        words_not_in_groups_mask = np.array(
            [word not in words_in_groups for word in words], dtype=bool)
        words_not_in_groups_sorted = [
            word for word in words if word not in words_in_groups
        ]
        non_group_embeddings = transformed_word_embeddings[
            words_not_in_groups_mask]

        # Plot words outside word group
        if interactive:
            fig.add_trace(
                go.Scatter(
                    x=non_group_embeddings[:, 0],
                    y=non_group_embeddings[:, 1],
                    mode="markers",
                    marker=dict(color=non_group_words_color),
                    hovertext=words_not_in_groups_sorted,
                    hoverinfo="x+y+text",
                    name="Non group words",
                    opacity=alpha,
                ))
        else:
            legend_handles.append(
                ax.scatter(
                    x=non_group_embeddings[:, 0],
                    y=non_group_embeddings[:, 1],
                    c=non_group_words_color,
                    alpha=alpha,
                ))
            legend_labels.append("Non group words")

    # Visualize words in groups
    for group_name, word_group in word_groups.items():
        words_in_group = set(word_group["words"])
        words_in_group_mask = np.array(
            [word in words_in_group for word in words], dtype=bool)
        words_in_group_sorted = [
            word for word in words if word in words_in_group
        ]
        word_group_color = word_group["color"]
        group_embeddings = transformed_word_embeddings[words_in_group_mask]

        # Plot words inside word group
        if interactive:
            fig.add_trace(
                go.Scatter(
                    x=group_embeddings[:, 0],
                    y=group_embeddings[:, 1],
                    mode="markers",
                    marker=dict(color=word_group_color),
                    hovertext=words_in_group_sorted,
                    hoverinfo="x+y+text",
                    name=f"Words in {group_name}",
                    opacity=alpha,
                ))
        else:
            legend_handles.append(
                ax.scatter(
                    x=group_embeddings[:, 0],
                    y=group_embeddings[:, 1],
                    c=word_group_color,
                    alpha=alpha,
                ))
            legend_labels.append(f"Words which are {group_name}")

    if interactive:
        fig.show()
    else:
        ax.legend(legend_handles, legend_labels)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if show_plot:
            plt.show()