def plot_all_trajs(ax: plt.axis, df: pd.DataFrame, radii: dict,
                   plot_rings=False, relative=False):
    """Plot all of the trajectories in the input DataFrame on the given axis.

    Args:
        ax: Matplotlib axis on which we want to plot.
        df: DataFrame containing trajectory data. It must provide a 'Trial'
            column plus either 'X_1', 'Y_1', 'X_2', 'Y_2' (absolute mode) or
            'rel_X', 'rel_Y' (relative mode).
        radii: Dictionary with inner and outer radii of N-ring (unpacked via
            ``extract_radii``).
        plot_rings: If True, draw the outer and inner circles of the N-ring,
            centred on the origin.
        relative: If True, plot the relative positions ('rel_X', 'rel_Y')
            instead of the two absolute trajectories.

    Returns:
        The axis that was drawn on.
    """
    cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True)
    outer_rad, inner_rad = extract_radii(radii)
    # ``ax`` must be passed explicitly: without it seaborn draws on the
    # *current* axes and silently ignores the axis supplied by the caller.
    kwargs = {'data': df, 'legend': False, 'palette': cmap,
              'edgecolor': None, 's': 20, 'ax': ax}

    if relative:
        sns.scatterplot(x='rel_X', y='rel_Y', hue='Trial', **kwargs)
    else:
        sns.scatterplot(x='X_1', y='Y_1', hue='Trial', **kwargs)
        sns.scatterplot(x='X_2', y='Y_2', hue='Trial', **kwargs)

    if plot_rings:
        for radius in (outer_rad, inner_rad):
            ax.add_artist(plt.Circle((0, 0), radius, color='blue', fill=False))

    ax.axis('equal')
    ax.set_xlabel('relative x position' if relative else 'x position')
    ax.set_ylabel('relative y position' if relative else 'y position')
    return ax
Beispiel #2
0
    def plot_areas_society_progress(ax: plt.axis, time_array: List,
                                    society_snapshot: Dict,
                                    society_progress: Dict):
        """Append the current snapshot to the history and draw it as stacked bands.

        For every member of ``Status`` (in enum order), the current value
        ``society_snapshot[st.name]`` is appended to
        ``society_progress[st.name]``. The band between the previous status'
        history and this status' history is then filled. The band's width at
        the latest time step is written as text at ``time_array[-1]``.

        NOTE(review): the bands are only meaningful if the snapshot values are
        cumulative over the ``Status`` order (each includes the previous ones).
        Confirm against the caller.

        Args:
            ax: Matplotlib axis to draw on.
            time_array: x values. Presumably one entry per recorded snapshot,
                i.e. the same length as each history list after appending.
            society_snapshot: Current value for each status name.
            society_progress: History list per status name; mutated in place.
        """
        ax.set_xlabel('time [days]')
        ax.set_ylabel('population percentage')

        # The first band starts at zero; fill_between broadcasts ``[0]``.
        # The lower layer is tracked explicitly instead of via the truthiness
        # of the previous Enum member (a falsy mixin member would reset it).
        lower_limit = [0]
        for st in Status:
            upper_history = society_progress[st.name]
            upper_history.append(society_snapshot[st.name])

            ax.fill_between(x=time_array,
                            y1=lower_limit,
                            y2=upper_history,
                            color=st.value,
                            label=st.name,
                            alpha=0.25)

            lower_now = lower_limit[-1]
            upper_now = society_snapshot[st.name]
            ax.text(x=time_array[-1],
                    y=1 / 2 * (lower_now + upper_now),
                    s=r"{0:.2f}".format(upper_now - lower_now),
                    size=10,
                    color=st.value)

            # The next status is stacked on top of this one.
            lower_limit = upper_history
Beispiel #3
0
    def _sub_plot_multiple_twin(self, axis: plt.axis, columns: tuple):
        '''
        Create a sub-plot for the given columns.
        Create a twin x-axis with new y scale for the second column.

        Parameters
        ----------
        axis : plt.axis
            Figure axis. ``columns[0]`` is drawn on it (blue, left y-axis).
        columns : tuple
            Two column names of ``self.m_data_df``. ``columns[1]`` is drawn on
            a twin axis sharing the x-axis (red, right y-axis).

        Returns
        -------
        plt.axis
            The twin axis holding the second column.
        '''
        # Plot first column on the left y-axis.
        # NOTE(review): linewidth 0.1 here vs 1 for the second column; confirm
        # this asymmetry is intended.
        axis.plot(self.m_data_df[columns[0]], label=columns[0], color='tab:blue',
                  linestyle='-', linewidth=0.1)
        axis.tick_params(axis='y', labelcolor='tab:blue')
        axis.set_xlabel('Rollouts')
        axis.set_ylabel(columns[0])
        axis.set_title(columns[0] + ' ~ ' + columns[1])

        # Plot second column on a twin axis with its own y scale.
        twin_axis = axis.twinx()
        twin_axis.plot(self.m_data_df[columns[1]], label=columns[1], color='tab:red',
                       linestyle='-', linewidth=1)
        twin_axis.set_ylabel(columns[1])
        twin_axis.tick_params(axis='y', labelcolor='tab:red')
        return twin_axis
Beispiel #4
0
    def plot_on(
        self,
        ax: plt.axis,
        draw_legend: bool = True,
        legend_inside: bool = True,
        legend_kwargs: dict = None,
        yaxis_scale=1.3,
        hide_labels: bool = False,
    ) -> plt.axis:
        """Draw every component on ``ax`` as error-bar points or error boxes.

        Args:
            ax: Matplotlib axis to draw on.
            draw_legend: Draw a legend if True.
            legend_inside: Place the legend inside the axes and multiply the
                upper y-limit by ``yaxis_scale`` to make room. Otherwise anchor
                it at (1, 1), the upper-right corner.
            legend_kwargs: Extra keyword arguments for ``ax.legend``. They
                override the defaults (``frameon=False`` and, for an outside
                legend, ``bbox_to_anchor=(1, 1)``).
            yaxis_scale: Factor applied to the upper y-limit when the legend
                is drawn inside.
            hide_labels: Skip setting the axis labels if True.

        Returns:
            The axis that was drawn on.

        Raises:
            NotImplementedError: If a component's style is neither 'point'
                nor 'box'.
        """
        for component in self.components:
            data = component.data
            if component.style == 'point':
                ax.errorbar(
                    data.x_values,
                    data.y_values,
                    yerr=data.y_errors,
                    xerr=data.x_errors,
                    label=component.label,
                    color=component.color,
                    ls=component.ls,
                    marker=component.marker,
                )
            elif component.style == 'box':
                # Box of height 2*y_errors starting at y_values - y_errors
                # and width 2*x_errors around x_values.
                ax.bar(data.x_values,
                       2 * data.y_errors,
                       width=2 * data.x_errors,
                       bottom=data.y_values - data.y_errors,
                       label=component.label,
                       color=component.color,
                       alpha=0.5)
            else:
                raise NotImplementedError(
                    "Options: point|box. If you require a new kind of plot, report a feature request"
                )

        if not hide_labels:
            # NOTE(review): the 2nd positional argument of set_xlabel/set_ylabel
            # is ``fontdict``; presumably plot_style.*label_pos are dicts of
            # text properties. Confirm.
            ax.set_xlabel(self.variable.x_label, plot_style.xlabel_pos)
            ax.set_ylabel(self.variable.y_label, plot_style.ylabel_pos)

        if draw_legend:
            # Merge instead of passing twice: caller-supplied keys now override
            # the defaults rather than raising "multiple values for keyword".
            extra = legend_kwargs or {}
            if legend_inside:
                ax.legend(**{'frameon': False, **extra})
                ylims = ax.get_ylim()
                ax.set_ylim(ylims[0], yaxis_scale * ylims[1])
            else:
                ax.legend(**{'frameon': False, 'bbox_to_anchor': (1, 1), **extra})

        return ax
Beispiel #5
0
def plot_cluster_sizes(cluster_labels: list,
                       ax: plt.axis = None,
                       show_plot: bool = True) -> np.ndarray:
    """
    Plots cluster sizes using a histogram and returns a list of most frequent
    cluster sizes.

    Also prints the number of clusters, the largest and smallest cluster size
    and their ratio.

    Parameters
    ----------
    cluster_labels : list
        List of cluster labels (one per clustered item)
    ax : plt.axis
        Matplotlib axis (default None, in which case a new figure is created)
    show_plot : bool
        Whether or not to call plt.show() (defaults to True)

    Returns
    -------
    most_common_cluster_sizes : np.ndarray
        Numpy array of the distinct cluster sizes, ordered from most to least
        frequent

    Raises
    ------
    ValueError
        If `cluster_labels` is empty.
    """
    if len(cluster_labels) == 0:
        raise ValueError("cluster_labels must not be empty")

    if ax is None:
        _, ax = plt.subplots()

    # labels_counts[i] is the size of cluster labels_unique[i]; counting those
    # sizes again gives how many clusters share each size.
    labels_unique, labels_counts = np.unique(cluster_labels,
                                             return_counts=True)
    cluster_sizes, cluster_size_counts = np.unique(labels_counts,
                                                   return_counts=True)

    # Print cluster size ratio (max / min)
    num_clusters = len(labels_unique)
    max_cluster_size = max(labels_counts)
    min_cluster_size = min(labels_counts)
    cluster_size_ratio = max_cluster_size / min_cluster_size
    print(
        f"{num_clusters} clusters: max={max_cluster_size}, min={min_cluster_size}, ratio={cluster_size_ratio}"
    )

    # Plot distribution of cluster sizes: each bar counts clusters (not
    # words) having a given size.
    sns.histplot(labels_counts, bins=max_cluster_size, ax=ax)
    ax.set_xlabel("Cluster size")
    ax.set_ylabel("Number of clusters")
    if show_plot:
        plt.show()

    # Sort cluster sizes by frequency (most frequent first)
    most_common_cluster_sizes = cluster_sizes[np.argsort(cluster_size_counts)
                                              [::-1]]

    return most_common_cluster_sizes
Beispiel #6
0
    def plot_lines_society_progress(ax: plt.axis, time_array: List,
                                    society_snapshot: Dict,
                                    society_progress: Dict):
        """Record each status' share of the total and draw its history as a line.

        For every member of ``Status``, the share
        ``society_snapshot[name] / society_snapshot["Total"]`` is appended to
        ``society_progress[name]`` (mutated in place). The whole history is
        plotted against ``time_array``, and the newest share is written as
        text at the line's last point.
        """
        ax.set_xlabel('time [days]')
        ax.set_ylabel('population percentage')
        for status in Status:
            name, colour = status.name, status.value
            history = society_progress[name]
            history.append(society_snapshot[name] / society_snapshot["Total"])
            ax.plot(time_array, history, c=colour, ls='-', label=name,
                    alpha=0.5)

            latest_share = history[-1]
            ax.text(x=time_array[-1], y=latest_share,
                    s=r"{0:.2f}".format(latest_share), size=10, color=colour)
Beispiel #7
0
    def mpl_plot(self,
                 ax: plt.axis = None,
                 temp_unit: str = "GK",
                 **kwargs) -> plt.axis:
        """
        Prepare a matplotlib axis for a reaction-rate plot.

        Sets the title and the y-axis label. The x-axis label is set only for
        ``temp_unit == "GK"``.

        Parameters
        ----------
        ax : mpl.Axis
            mpl Axis to plot onto; if None, the current axis (``plt.gca()``)
            is used.
        temp_unit : str {"GK", "KeV"}
            Temperature units for x axis. NOTE(review): only "GK" currently
            sets an x label; "KeV" leaves it untouched. Confirm.
        kwargs : key word args for mpl.plot
            NOTE(review): currently unused by this method.

        Returns
        -------
        mpl.Axis
            The configured axis.
        """
        if ax is None:
            ax = plt.gca()
        ax.set_title("Reaction Rate")
        ax.set_ylabel(r"Rate ($cm^3\;mol^{-1}\;sec^{-1}$)")

        # Compare by value: ``is`` on a str literal depends on interning and
        # triggers a SyntaxWarning on Python 3.8+.
        if temp_unit == "GK":
            ax.set_xlabel("Temperature ($GK$)")

        return ax
    def plot_on(self,
                ax: plt.axis,
                ylabel="Events",
                draw_legend=True,
                legend_inside=True,
                hide_labels: bool = False,
                yaxis_scale=1.4):
        """Draw the 'stacked' MC components as one stacked histogram on ``ax``.

        The bin edges, mids and width from ``self._get_bin_edges()`` are cached
        on ``self`` (``_bin_edges``, ``_bin_mids``, ``_bin_width``).

        Args:
            ax: Matplotlib axis to draw on.
            ylabel: Passed to ``self._get_y_label`` to build the y-axis label.
            draw_legend: Draw a legend if True.
            legend_inside: Place the legend inside the axes and multiply the
                upper y-limit by ``yaxis_scale``. Otherwise anchor it at
                (1, 1), the upper-right corner.
            hide_labels: Skip setting the axis labels if True.
            yaxis_scale: Factor applied to the upper y-limit when the legend
                is drawn inside (default 1.4, the previously hard-coded value).

        Returns:
            The axis that was drawn on.
        """
        bin_edges, bin_mids, bin_width = self._get_bin_edges()

        self._bin_edges = bin_edges
        self._bin_mids = bin_mids
        self._bin_width = bin_width

        stacked = self._mc_components['stacked']
        ax.hist(
            x=[comp.data for comp in stacked],
            bins=bin_edges,
            weights=[comp.weights for comp in stacked],
            stacked=True,
            edgecolor="black",
            lw=0.3,
            color=[comp.color for comp in stacked],
            label=[comp.label for comp in stacked],
            histtype='stepfilled')

        if not hide_labels:
            ax.set_xlabel(self._variable.x_label, plot_style.xlabel_pos)
            y_label = self._get_y_label(False, bin_width, ylabel)
            ax.set_ylabel(y_label, plot_style.ylabel_pos)
        if draw_legend:
            if legend_inside:
                ax.legend(frameon=False)
                ylims = ax.get_ylim()
                ax.set_ylim(ylims[0], yaxis_scale * ylims[1])
            else:
                ax.legend(frameon=False, bbox_to_anchor=(1, 1))

        return ax
Beispiel #9
0
    def _plot_summary_on_axis(
        self,
        ax: plt.axis,
        label_y_axis: bool,
        use_title: bool,
    ):
        """used to plot summary on multi-axis figure, or in standalone figure

        For every group name, curves are first averaged across paradigms per
        replication. The mean over replications is drawn together with a
        shaded margin of error (standard error times the Student-t quantile
        for ``self.confidence``). The RoBERTa-base and frequency-baseline
        averages are drawn once as horizontal reference lines.

        Args:
            ax: axis to draw on.
            label_y_axis: draw the y label and y ticks if True, otherwise
                clear them.
            use_title: set the title to 'Average' (keeping the plain y label)
                and print the summary; printing is gated on this flag so the
                summary is printed only once.
        """

        # axis
        if use_title:
            ax.set_title('Average', fontsize=configs.Figs.title_font_size)
            y_axis_label = self.y_axis_label
        else:
            y_axis_label = f'Average {self.y_axis_label}'
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_ylim(self.y_lims)

        # x-axis
        ax.set_xticks([self.last_step])
        ax.set_xticklabels([shorten_tick_label(self.last_step)],
                           fontsize=configs.Figs.tick_font_size)
        ax.set_xlabel(self.x_axis_label, fontsize=configs.Figs.ax_font_size)

        # y axis
        if label_y_axis:
            ax.set_ylabel(y_axis_label, fontsize=configs.Figs.ax_font_size)
            ax.set_yticks(self.y_ticks)
            ax.set_yticklabels(self.y_ticks,
                               fontsize=configs.Figs.tick_font_size)
        else:
            ax.set_ylabel('', fontsize=configs.Figs.ax_font_size)
            ax.set_yticks([])
            ax.set_yticklabels([], fontsize=configs.Figs.tick_font_size)

        # collect curves for each replication across all paradigms
        # (loop variable is not named ``pd`` to avoid shadowing pandas)
        gn2rep2curves_by_pd = defaultdict(dict)
        for paradigm in self.pds:
            for gn, rep2curve in paradigm.group_name2rep2curve.items():
                for rep, curve in rep2curve.items():
                    # this curve is performance collapsed across template and for a unique rep and paradigm
                    gn2rep2curves_by_pd[gn].setdefault(rep, []).append(curve)

        x = np.arange(0, self.last_step + self.step_size, self.step_size)

        # plot
        for gn, rep2curves_by_pd in gn2rep2curves_by_pd.items():
            # average across paradigms -> one curve for each rep
            curves = np.array([
                np.array(curves_by_pd).mean(axis=0)
                for curves_by_pd in rep2curves_by_pd.values()
            ])

            color = f'C{self.pds[0].group_names.index(gn)}'

            # plot averages for BabyBERTa
            y = curves.mean(axis=0)
            ax.plot(x, y, linewidth=self.line_width, color=color)

            # plot the margin of error (shaded region)
            n = len(curves)
            h = sem(curves, axis=0) * t.ppf(
                (1 + self.confidence) / 2, n - 1)  # margin of error
            ax.fill_between(x, y + h, y - h, alpha=0.2, color=color)

            # printout
            if use_title:  # to prevent printing summary twice
                print(f'{gn} avg acc at step {self.last_step} = {y[-1]:.3f}')

        # The reference lines do not depend on the group, so draw them once.
        # They were previously redrawn for every group, which stacked
        # duplicate artists and darkened any semi-transparent style.
        if gn2rep2curves_by_pd:
            # plot average for RoBERTa-base
            y_roberta_base = np.repeat(
                np.mean(list(self.paradigm2roberta_base_accuracy.values())),
                len(x))
            ax.plot(x,
                    y_roberta_base,
                    linewidth=self.line_width,
                    **self.ax_kwargs_roberta_base)

            # plot average for frequency baseline
            y_baseline = np.repeat(
                np.mean(list(self.paradigm2baseline_accuracy.values())),
                len(x))
            ax.plot(x,
                    y_baseline,
                    linewidth=self.line_width,
                    **self.ax_kwargs_baseline)

        if use_title:
            y_roberta_base = np.mean(
                list(self.paradigm2roberta_base_accuracy.values()))
            print(
                f'roberta-base Liu2019 avg acc at step {self.last_step} = {y_roberta_base:.3f}'
            )
Beispiel #10
0
    def plot_on(
        self,
        ax1: plt.axis,
        ax2,
        style="stacked",
        ylabel="Events",
        sum_color=plot_style.KITColors.kit_purple,
        draw_legend: bool = True,
        legend_inside: bool = True,
    ):
        """Draw MC against data on ``ax1`` and their relative difference on ``ax2``.

        The bin edges, mids and width from ``self._get_bin_edges()`` are cached
        on ``self`` (``_bin_edges``, ``_bin_mids``, ``_bin_width``).

        Args:
            ax1: Main axis. Shows the MC (according to ``style``) and the data
                as points with sqrt(N) error bars.
            ax2: Ratio axis showing (Data - MC) / Data with uncertainties
                propagated by ``uncertainties.unumpy``; y-range fixed to (-1, 1).
                Bins without data are skipped.
            style: Case-insensitive. "stacked" draws the MC components as a
                stacked histogram plus a hatched "MC stat. unc." band.
                "summed" draws only a filled band sum_w +/- sqrt(sum_w2) in
                ``sum_color``. NOTE(review): any other value draws no MC.
            ylabel: Passed to ``self._get_y_label`` to build the y-axis label.
            sum_color: Fill colour used by the "summed" style.
            draw_legend: Draw a legend on ``ax1`` if True.
            legend_inside: Place the legend inside ``ax1`` and multiply its
                upper y-limit by 1.4. Otherwise anchor it at (1, 1), the
                upper-right corner.
        """
        bin_edges, bin_mids, bin_width = self._get_bin_edges()

        self._bin_edges = bin_edges
        self._bin_mids = bin_mids
        self._bin_width = bin_width

        mc_components = self._mc_components["MC"]

        def _summed_per_bin(values_of):
            # Per-bin sum of values_of(comp), added up over all MC components.
            per_component = [
                binned_statistic(comp.data,
                                 values_of(comp),
                                 statistic="sum",
                                 bins=bin_edges)[0]
                for comp in mc_components
            ]
            if not per_component:
                return np.zeros(len(bin_edges) - 1)
            return np.sum(per_component, axis=0)

        sum_w = _summed_per_bin(lambda comp: comp.weights)
        sum_w2 = _summed_per_bin(lambda comp: comp.weights**2)
        mc_unc = np.sqrt(sum_w2)  # MC statistical uncertainty per bin

        hdata, _ = np.histogram(self._data_component.data, bins=bin_edges)

        style = style.lower()
        if style == "stacked":
            ax1.hist(
                x=[comp.data for comp in mc_components],
                bins=bin_edges,
                weights=[comp.weights for comp in mc_components],
                stacked=True,
                edgecolor="black",
                lw=0.3,
                color=[comp.color for comp in mc_components],
                label=[comp.label for comp in mc_components],
                histtype='stepfilled')

            # Use the bin width computed above; the original read
            # ``self.bin_width``, which is not set by this method.
            ax1.bar(x=bin_mids,
                    height=2 * mc_unc,
                    width=bin_width,
                    bottom=sum_w - mc_unc,
                    color="black",
                    hatch="///////",
                    fill=False,
                    lw=0,
                    label="MC stat. unc.")
        elif style == "summed":
            ax1.bar(x=bin_mids,
                    height=2 * mc_unc,
                    width=bin_width,
                    bottom=sum_w - mc_unc,
                    color=sum_color,
                    lw=0,
                    label="MC")

        ax1.errorbar(x=bin_mids,
                     y=hdata,
                     yerr=np.sqrt(hdata),
                     ls="",
                     marker=".",
                     color="black",
                     label=self._data_component.label)

        y_label = self._get_y_label(False, bin_width, evts_or_cand=ylabel)
        ax1.set_ylabel(y_label, plot_style.ylabel_pos)

        if draw_legend:
            if legend_inside:
                ax1.legend(frameon=False)
                ylims = ax1.get_ylim()
                ax1.set_ylim(ylims[0], 1.4 * ylims[1])
            else:
                ax1.legend(frameon=False, bbox_to_anchor=(1, 1))

        ax2.set_ylabel(r"$\frac{\mathrm{Data - MC}}{\mathrm{Data}}$")
        ax2.set_xlabel(self._variable.x_label, plot_style.xlabel_pos)
        ax2.set_ylim((-1, 1))
        ax2.axhline(y=0, color=plot_style.KITColors.dark_grey, alpha=0.8)

        # The ratio is undefined where a data bin is empty: skip only those
        # bins instead of dropping the whole ratio plot on ZeroDivisionError.
        filled = hdata > 0
        if np.any(filled):
            uhdata = unp.uarray(hdata[filled], np.sqrt(hdata[filled]))
            uhmc = unp.uarray(sum_w[filled], mc_unc[filled])
            ratio = (uhdata - uhmc) / uhdata

            ax2.errorbar(np.asarray(bin_mids)[filled],
                         unp.nominal_values(ratio),
                         yerr=unp.std_devs(ratio),
                         ls="",
                         marker=".",
                         color=plot_style.KITColors.kit_black)

        # Adjust the figure that owns these axes, not whatever is current.
        ax1.figure.subplots_adjust(hspace=0.08)
Beispiel #11
0
    def plot_on(self,
                ax: plt.axis,
                draw_legend: bool = True,
                legend_inside: bool = True,
                yaxis_scale=1.3,
                normed: bool = False,
                ylabel="Events",
                hide_labels: bool = False) -> plt.axis:
        """
        Draws each 'single' MC component as its own histogram on the given
        matplotlib.pyplot.axis.

        :param ax: matplotlib.pyplot.axis where the histograms will be drawn
        on.
        :param draw_legend: Draw legend on axis if True.
        :param legend_inside: Put the legend inside the axes and multiply the
        upper y-limit by ``yaxis_scale``; otherwise anchor it at (1, 1).
        :param yaxis_scale: Upper y-limit factor used with an inside legend.
        :param normed: If true the histograms are normalized (density).
        :param ylabel: Passed to ``self._get_y_label`` to build the y label.
        :param hide_labels: Skip setting the axis labels if True.

        :return: matplotlib.pyplot.axis with histogram drawn on it
        """
        bin_edges, bin_mids, bin_width = self._get_bin_edges()
        self._bin_edges, self._bin_mids, self._bin_width = (
            bin_edges, bin_mids, bin_width)

        for comp in self._mc_components['single']:
            # Filled histograms get a black outline and transparency;
            # the other histtypes are opaque and outlined in their own colour.
            is_filled = comp.histtype == 'stepfilled'
            ax.hist(x=comp.data,
                    bins=bin_edges,
                    density=normed,
                    weights=comp.weights,
                    histtype=comp.histtype,
                    label=comp.label,
                    edgecolor='black' if is_filled else comp.color,
                    alpha=0.6 if is_filled else 1.0,
                    lw=1.5,
                    ls=comp.ls,
                    color=comp.color)

        if not hide_labels:
            ax.set_xlabel(self._variable.x_label, plot_style.xlabel_pos)
            ax.set_ylabel(
                self._get_y_label(normed=normed,
                                  bin_width=bin_width,
                                  evts_or_cand=ylabel),
                plot_style.ylabel_pos)

        if not draw_legend:
            return ax
        if not legend_inside:
            ax.legend(frameon=False, bbox_to_anchor=(1, 1))
            return ax

        ax.legend(frameon=False)
        y_low, y_high = ax.get_ylim()
        ax.set_ylim(y_low, yaxis_scale * y_high)
        return ax
def plot_word_vectors(
    transformed_word_embeddings: np.ndarray,
    words: list,
    title: str,
    x_label: str,
    y_label: str,
    word_colors: np.ndarray = None,
    ax: plt.axis = None,
    show_plot: bool = True,
    interactive: bool = False,
    continuous_word_colors: bool = False,
) -> None:
    """
    Scatter-plots 2D word vectors, either with matplotlib or with Plotly.

    Parameters
    ----------
    transformed_word_embeddings : np.ndarray
        Word embeddings in 2D; column 0 is x, column 1 is y.
    words : list
        Words shown on hover in the interactive plot.
    title : str
        Title of the plot.
    x_label : str
        Label of the x-axis.
    y_label : str
        Label of the y-axis.
    word_colors : np.ndarray, optional
        One label per word (e.g. cluster labels) used for colouring
        (defaults to None).
    ax : plt.axis
        Matplotlib axis for the static plot; a new one is created when None.
    show_plot : bool
        Whether or not to call plt.show() for the static plot (defaults to True).
    interactive : bool
        Use Plotly instead of matplotlib (defaults to False). The Plotly
        figure is always shown.
    continuous_word_colors : bool
        Interactive only: pass colours as-is instead of as strings
        (defaults to False).
    """
    if not interactive:
        if ax is None:
            _, ax = plt.subplots()
        ax.scatter(
            transformed_word_embeddings[:, 0],
            transformed_word_embeddings[:, 1],
            c=word_colors,
        )
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if show_plot:
            plt.show()
        return

    # String labels give Plotly a discrete colour map; raw values allow a
    # continuous one.
    if word_colors is None:
        point_colors = None
    elif continuous_word_colors:
        point_colors = list(word_colors)
    else:
        point_colors = [str(label) for label in word_colors]

    fig = px.scatter(
        x=transformed_word_embeddings[:, 0],
        y=transformed_word_embeddings[:, 1],
        title=title,
        labels={"x": x_label, "y": y_label},
        color=point_colors,
        hover_data={"word": words},
    )
    fig.show()
Beispiel #13
0
def plot_cluster_metric_scores(
    metric_scores: list,
    hyperparameters: list,
    best_score_idx: int,
    metric_name: str,
    scatter: bool = True,
    set_xticks: bool = True,
    set_xtickslabels: bool = True,
    xtickslabels_rotation: int = 90,
    ax: plt.axis = None,
    xlabel: str = "Hyperparameters",
    xrange: range = None,
    show_plot: bool = True,
) -> None:
    """
    Plots internal cluster validation metric scores and highlights the best one.

    Parameters
    ----------
    metric_scores : list
        Scores computed using the metric, one per hyperparameter setting
    hyperparameters : list
        Hyperparameters used to compute the scores (used as x tick labels)
    best_score_idx : int
        Index of the best score; that point is marked with a larger red dot
    metric_name : str
        Name of the internal cluster validation metric (y label is
        "<metric_name> score")
    scatter : bool
        Whether or not to also draw every score as a point (defaults to True)
    set_xticks : bool
        Whether or not to put a tick at every x position
    set_xtickslabels : bool
        Whether or not to label the x ticks with `hyperparameters`
    xtickslabels_rotation : int
        Rotation of the x tick labels (defaults to 90); only has an effect
        when set_xtickslabels is True.
    ax : plt.axis
        Matplotlib axis (defaults to None, creating a new figure)
    xlabel : str
        X-axis label (defaults to "Hyperparameters")
    xrange : range
        X positions of the scores (defaults to range(len(hyperparameters)))
    show_plot : bool
        Whether or not to call plt.tight_layout() and plt.show()
        (defaults to True)
    """
    if ax is None:
        _, ax = plt.subplots()
    positions = range(len(hyperparameters)) if xrange is None else xrange

    ax.plot(positions, metric_scores)
    if scatter:
        ax.scatter(positions, metric_scores)
    # Best score: larger red marker drawn above everything else.
    ax.scatter(positions[best_score_idx],
               metric_scores[best_score_idx],
               c="r",
               s=72,
               zorder=10)

    if set_xticks:
        ax.set_xticks(positions)
    if set_xtickslabels:
        ax.set_xticklabels(hyperparameters,
                           rotation=xtickslabels_rotation,
                           ha="center")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(f"{metric_name} score")

    if not show_plot:
        return
    plt.tight_layout()
    plt.show()
Beispiel #14
0
def _restrict_word_groups(
    transformed_word_embeddings: np.ndarray,
    word_to_idx: dict,
    word_groups: dict,
) -> dict:
    """Map each group key to its vocabulary words lying inside the group's boundaries."""
    restricted = {}
    for group_key, group_data in word_groups.items():
        group_words = np.array(
            [word for word in group_data["words"] if word in word_to_idx])
        if group_words.size == 0:
            # No word of this group is in the vocabulary: nothing to bound.
            restricted[group_key] = group_words
            continue

        group_word_embeddings = transformed_word_embeddings[
            [word_to_idx[word] for word in group_words]]
        xs = group_word_embeddings[:, 0]
        ys = group_word_embeddings[:, 1]

        # Work on a copy so the defaults filled in below never leak back into
        # the caller's `word_groups` (they used to be written into it).
        boundaries = dict(group_data.get("boundaries") or {})
        if boundaries.get("xmin") is None:
            boundaries["xmin"] = xs.min()
        if boundaries.get("xmax") is None:
            boundaries["xmax"] = xs.max()
        if boundaries.get("ymin") is None:
            boundaries["ymin"] = ys.min()
        if boundaries.get("ymax") is None:
            boundaries["ymax"] = ys.max()

        inside = ((boundaries["xmin"] <= xs) & (xs <= boundaries["xmax"])
                  & (boundaries["ymin"] <= ys) & (ys <= boundaries["ymax"]))
        restricted[group_key] = group_words[inside]
    return restricted


def word_group_visualization(
    transformed_word_embeddings: np.ndarray,
    words: np.ndarray,
    word_groups: dict,
    xlabel: str,
    ylabel: str,
    emphasis_words: list = None,
    alpha: float = 1,
    non_group_words_color: str = "#ccc",
    scatter_set_rasterized: bool = False,
    rasterization_threshold: int = 1000,
    ax: plt.axis = None,
    show_plot: bool = True,
) -> None:
    """
    Visualizes one or more word groups by plotting its word embeddings in 2D.

    Parameters
    ----------
    transformed_word_embeddings : np.ndarray
        Transformed word embeddings; row i belongs to words[i], columns 0/1
        are x/y.
    words : np.ndarray
        Numpy array containing all words from vocabulary.
    word_groups : dict
        Dictionary containing word groups to visualize. Each value needs
        "words", "color" and "label" entries and may have a "boundaries" dict
        with any of "xmin", "xmax", "ymin", "ymax". Missing boundaries default
        to the extent of the group's words. `word_groups` is not modified.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    emphasis_words : list, optional
        List representing words to emphasize in the visualization (defaults to None).
        Entries can be either strings (words) or tuples consisting of the
        word, x-offset and y-offset.
    alpha : float
        Scatter plot alpha value (defaults to 1).
    non_group_words_color : str
        Color for words outside groups (defaults to #ccc).
    scatter_set_rasterized : bool
        Whether or not to enable rasterization on scatter plotting (defaults to False).
    rasterization_threshold : int
        The least number of data points to enable rasterization, given that
        `scatter_set_rasterized` is set to True (defaults to 1000).
    ax : plt.axis
        Axis (defaults to None).
    show_plot : bool
        Whether or not to call plt.show() (defaults to True).

    Raises
    ------
    ValueError
        If an emphasis word is not in `words`.
    """
    # Index of the first occurrence of every vocabulary word, so lookups do
    # not rescan `words` for every group word / emphasis word.
    word_to_idx = {}
    for idx, word in enumerate(words):
        word_to_idx.setdefault(word, idx)

    # Filter and restrict words in word groups
    word_group_words_restricted = _restrict_word_groups(
        transformed_word_embeddings, word_to_idx, word_groups)

    # Find words not in any (restricted) group. Previously an index was added
    # once per group that lacked the word, duplicating points and drawing
    # words of one group as "non-group" whenever a second group existed.
    words_in_groups = set()
    for group_words in word_group_words_restricted.values():
        words_in_groups.update(group_words)
    non_group_indices = [
        i for i, word in enumerate(words) if word not in words_in_groups
    ]

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 7))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Plot non-group words
    non_group_embeddings = transformed_word_embeddings[non_group_indices]
    non_grp_scatter_handle = ax.scatter(
        x=non_group_embeddings[:, 0],
        y=non_group_embeddings[:, 1],
        s=10,
        alpha=alpha,
        c=non_group_words_color,
    )
    if (scatter_set_rasterized
            and len(non_group_indices) >= rasterization_threshold):
        non_grp_scatter_handle.set_rasterized(True)

    # Plot group words
    for group_key, group_words in word_group_words_restricted.items():
        group_word_embeddings = transformed_word_embeddings[
            [word_to_idx[word] for word in group_words]]

        grp_scatter_handle = ax.scatter(
            x=group_word_embeddings[:, 0],
            y=group_word_embeddings[:, 1],
            s=15,
            alpha=alpha,
            c=word_groups[group_key]["color"],
            label=word_groups[group_key]["label"],
        )
        if (scatter_set_rasterized
                and len(group_word_embeddings) >= rasterization_threshold):
            grp_scatter_handle.set_rasterized(True)

    # Visualize emphasized words
    if emphasis_words is not None:
        # isinstance also accepts np.str_ entries taken from numpy arrays.
        emphasis_words = [(entry, 0, 0) if isinstance(entry, str) else entry
                          for entry in emphasis_words]
        for emphasis_word, x_offset, y_offset in emphasis_words:
            # Colour of the first group listing the word (unrestricted list).
            word_color = non_group_words_color
            for group_data in word_groups.values():
                if emphasis_word in group_data["words"]:
                    word_color = group_data["color"]
                    break

            if emphasis_word not in word_to_idx:
                raise ValueError(
                    f"Emphasis word {emphasis_word!r} is not in `words`")
            word_idx = word_to_idx[emphasis_word]

            emphasis_scatter_handle = ax.scatter(
                x=transformed_word_embeddings[word_idx, 0],
                y=transformed_word_embeddings[word_idx, 1],
                s=40,
                alpha=alpha,
                c=word_color,
            )
            if (scatter_set_rasterized
                    and len(emphasis_words) >= rasterization_threshold):
                emphasis_scatter_handle.set_rasterized(True)

            # Annotate emphasis word with a text box
            offsetbox = TextArea(emphasis_word)
            ab = AnnotationBbox(
                offsetbox,
                tuple(transformed_word_embeddings[word_idx]),
                xybox=(x_offset, 40 + y_offset),
                xycoords="data",
                boxcoords="offset points",
                arrowprops=dict(arrowstyle="->", color="black", linewidth=2),
            )
            ax.add_artist(ab)

    ax.legend()
    if show_plot:
        plt.show()
Beispiel #15
0
def visualize_word_cluster_groups(
    transformed_word_embeddings: np.ndarray,
    words: np.ndarray,
    word_groups: dict,
    visualize_non_group_words: bool,
    xlabel: str,
    ylabel: str,
    non_group_words_color: str = "#ccc",
    ax: plt.axis = None,
    show_plot: bool = True,
    alpha: float = 1,
    interactive: bool = False,
) -> None:
    """
    Visualizes word cluster groups.

    Parameters
    ----------
    transformed_word_embeddings : np.ndarray
        Transformed word embeddings; row i belongs to words[i], columns 0/1
        are x/y.
    words : np.ndarray
        Numpy array containing all words from vocabulary.
    word_groups : dict
        Dictionary containing word groups to visualize; each value needs
        "words" and "color" entries.
    visualize_non_group_words : bool
        Whether or not to visualize words outside word groups
    xlabel : str
        X-axis label
    ylabel : str
        Y-axis label
    non_group_words_color : str
        Color for words outside groups (defaults to #ccc)
    ax : plt.axis
        Matplotlib axis (defaults to None)
    show_plot : bool
        Whether or not to call plt.show() (defaults to True)
    alpha : float
        Scatter plot alpha value (defaults to 1)
    interactive : bool
        Whether or not to make the visualization interactive
        using Plotly (defaults to False).
    """
    if ax is None and not interactive:
        _, ax = plt.subplots(figsize=(12, 7))
    if interactive:
        fig = go.Figure(
            layout=dict(xaxis=dict(title=xlabel), yaxis=dict(title=ylabel)))

    if visualize_non_group_words:

        # Create boolean mask for words outside groups (set: O(1) lookups)
        words_in_groups = set()
        for word_group in word_groups.values():
            words_in_groups.update(word_group["words"])
        words_not_in_groups_mask = [
            word not in words_in_groups for word in words
        ]
        words_not_in_groups_sorted = [
            word for word in words if word not in words_in_groups
        ]
        non_group_embeddings = transformed_word_embeddings[
            words_not_in_groups_mask]

        # Plot words outside word group
        if interactive:
            fig.add_trace(
                go.Scatter(
                    x=non_group_embeddings[:, 0],
                    y=non_group_embeddings[:, 1],
                    mode="markers",
                    marker=dict(color=non_group_words_color),
                    hovertext=words_not_in_groups_sorted,
                    hoverinfo="x+y+text",
                    name="Non group words",
                    opacity=alpha,
                ))
        else:
            ax.scatter(
                x=non_group_embeddings[:, 0],
                y=non_group_embeddings[:, 1],
                c=non_group_words_color,
                alpha=alpha,
                label="Non group words",
            )

    # Visualize words in groups
    for group_name, word_group in word_groups.items():
        words_in_group = set(word_group["words"])
        words_in_group_mask = [word in words_in_group for word in words]
        words_in_group_sorted = [
            word for word in words if word in words_in_group
        ]
        word_group_color = word_group["color"]
        group_embeddings = transformed_word_embeddings[words_in_group_mask]

        # Plot words inside word group
        if interactive:
            fig.add_trace(
                go.Scatter(
                    x=group_embeddings[:, 0],
                    y=group_embeddings[:, 1],
                    mode="markers",
                    marker=dict(color=word_group_color),
                    hovertext=words_in_group_sorted,
                    hoverinfo="x+y+text",
                    name=f"Words in {group_name}",
                    opacity=alpha,
                ))
        else:
            ax.scatter(
                x=group_embeddings[:, 0],
                y=group_embeddings[:, 1],
                c=word_group_color,
                alpha=alpha,
                label=f"Words which are {group_name}",
            )

    if interactive:
        fig.show()
    else:
        # Labels are attached to each scatter above, so the legend stays
        # correct when the non-group scatter is skipped. A positional label
        # list used to shift every group's label by one in that case.
        ax.legend()
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if show_plot:
            plt.show()