def plot_all_trajs(ax: plt.axis, df: pd.DataFrame, radii: dict, plot_rings=False, relative=False):
    """Plot all of the trajectories in the input DataFrame on the given axis.

    Args:
        ax: Matplotlib axis on which we want to plot. If None, seaborn falls
            back to the current axes.
        df: DataFrame containing trajectory data. Must have a 'Trial' column
            (used for colouring) and either 'X_1', 'Y_1', 'X_2', 'Y_2'
            (absolute mode) or 'rel_X', 'rel_Y' (relative mode).
        radii: Dictionary with inner and outer radii of N-ring.
        plot_rings: Boolean flag that plots outer and inner circles of N-ring
            if True.
        relative: If truthy, plot the 'rel_X'/'rel_Y' columns and label the
            axes as relative positions; otherwise plot both absolute
            trajectories ('X_1'/'Y_1' and 'X_2'/'Y_2').
    """
    cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True)
    outer_rad, inner_rad = extract_radii(radii)

    scatter_kwargs = {
        'data': df,
        'hue': 'Trial',
        'legend': False,
        'palette': cmap,
        'edgecolor': None,
        's': 20,
    }

    # Pass `ax` explicitly: without it seaborn draws on whatever the current
    # axes happen to be and the caller-supplied axis is silently ignored.
    if relative:
        ax = sns.scatterplot(x='rel_X', y='rel_Y', ax=ax, **scatter_kwargs)
    else:
        ax = sns.scatterplot(x='X_1', y='Y_1', ax=ax, **scatter_kwargs)
        sns.scatterplot(x='X_2', y='Y_2', ax=ax, **scatter_kwargs)

    if plot_rings:
        # Both rings are unfilled blue circles centred on the origin.
        for radius in (outer_rad, inner_rad):
            ax.add_artist(plt.Circle((0, 0), radius, color='blue', fill=False))

    ax.axis('equal')

    # Use the same truthiness test as the plotting branch so the labels can
    # never disagree with the data drawn (e.g. for numpy booleans).
    ax.set_xlabel('relative x position' if relative else 'x position')
    ax.set_ylabel('relative y position' if relative else 'y position')
def word_group_visualization(
    transformed_word_embeddings: np.ndarray,
    words: np.ndarray,
    word_groups: dict,
    xlabel: str,
    ylabel: str,
    emphasis_words: list = None,
    alpha: float = 1,
    non_group_words_color: str = "#ccc",
    scatter_set_rasterized: bool = False,
    rasterization_threshold: int = 1000,
    ax: plt.axis = None,
    show_plot: bool = True,
) -> None:
    """
    Visualizes one or more word groups by plotting its word embeddings in 2D.

    Parameters
    ----------
    transformed_word_embeddings : np.ndarray
        Transformed word embeddings. Columns 0 and 1 are used as the x and y
        coordinates.
    words : np.ndarray
        Numpy array containing all words from vocabulary.
    word_groups : dict
        Dictionary containing word groups to visualize. Each value is a dict
        with the keys "words", "color" and "label", and optionally
        "boundaries" (a dict with any of "xmin", "xmax", "ymin", "ymax";
        missing or None entries default to the extent of the group's points).
        The dictionary is not modified.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    emphasis_words : list, optional
        List representing words to emphasize in the visualization (defaults to None).
        Entries can be either be strings (words) or tuples, consisting of the word,
        x-offset and y-offset.
    alpha : float
        Scatter plot alpha value (defaults to 1).
    non_group_words_color : str
        Color for words outside groups (defaults to #ccc).
    scatter_set_rasterized : bool
        Whether or not to enable rasterization on scatter plotting (defaults to False).
    rasterization_threshold : int
        The least number of data points to enable rasterization, given that
        `scatter_set_rasterized` is set to True (defaults to 1000).
    ax : plt.axis
        Axis (defaults to None). A new 12x7 figure is created when None.
    show_plot : bool
        Whether or not to call plt.show() (defaults to True).

    Raises
    ------
    IndexError
        If an emphasis word is not present in `words`.
    """
    # Index of the first occurrence of every vocabulary word; replaces a
    # linear np.where scan per looked-up word.
    word_to_index = {}
    for idx, word in enumerate(words):
        word_to_index.setdefault(word, idx)

    # Filter and restrict words in word groups
    restricted_group_indices = {}
    for group_key, group_data in word_groups.items():
        # dtype=int keeps the array usable as an index even when it is empty.
        group_indices = np.array(
            [
                word_to_index[word]
                for word in group_data["words"]
                if word in word_to_index
            ],
            dtype=int,
        )

        if len(group_indices) > 0:
            group_embeddings = transformed_word_embeddings[group_indices]
            xs = group_embeddings[:, 0]
            ys = group_embeddings[:, 1]

            # Read the boundaries without writing the computed defaults back
            # into the caller's dictionary.
            boundaries = group_data.get("boundaries") or {}
            xmin = boundaries.get("xmin")
            xmax = boundaries.get("xmax")
            ymin = boundaries.get("ymin")
            ymax = boundaries.get("ymax")
            if xmin is None:
                xmin = xs.min()
            if xmax is None:
                xmax = xs.max()
            if ymin is None:
                ymin = ys.min()
            if ymax is None:
                ymax = ys.max()

            inside_boundaries = (
                (xmin <= xs) & (xs <= xmax) & (ymin <= ys) & (ys <= ymax)
            )
            group_indices = group_indices[inside_boundaries]

        restricted_group_indices[group_key] = group_indices

    # Find words not in groups: a word is "non-group" only if it belongs to
    # none of the restricted groups, and each such word is listed once.
    grouped_words = {
        words[idx]
        for group_indices in restricted_group_indices.values()
        for idx in group_indices
    }
    non_group_indices = np.array(
        [idx for idx, word in enumerate(words) if word not in grouped_words],
        dtype=int,
    )

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 7))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Plot non-group words
    non_group_embeddings = transformed_word_embeddings[non_group_indices]
    non_grp_scatter_handle = ax.scatter(
        x=non_group_embeddings[:, 0],
        y=non_group_embeddings[:, 1],
        s=10,
        alpha=alpha,
        c=non_group_words_color,
    )
    if (scatter_set_rasterized
            and len(non_group_indices) >= rasterization_threshold):
        non_grp_scatter_handle.set_rasterized(True)

    # Plot group words
    for group_key, group_indices in restricted_group_indices.items():
        group_embeddings = transformed_word_embeddings[group_indices]
        grp_scatter_handle = ax.scatter(
            x=group_embeddings[:, 0],
            y=group_embeddings[:, 1],
            s=15,
            alpha=alpha,
            c=word_groups[group_key]["color"],
            label=word_groups[group_key]["label"],
        )
        if (scatter_set_rasterized
                and len(group_embeddings) >= rasterization_threshold):
            grp_scatter_handle.set_rasterized(True)

    # Visualize emphasized words
    if emphasis_words is not None:
        # Normalise plain strings to (word, x_offset, y_offset) triples.
        emphasis_entries = [
            (entry, 0, 0) if isinstance(entry, str) else entry
            for entry in emphasis_words
        ]
        for emphasis_word, x_offset, y_offset in emphasis_entries:
            # Colour of the first group that lists the word (unrestricted
            # word list); otherwise the non-group colour.
            word_color = next(
                (
                    group_data["color"]
                    for group_data in word_groups.values()
                    if emphasis_word in group_data["words"]
                ),
                non_group_words_color,
            )

            word_idx = word_to_index.get(emphasis_word)
            if word_idx is None:
                raise IndexError(
                    f"Emphasis word {emphasis_word!r} is not in the vocabulary"
                )

            emphasis_scatter_handle = ax.scatter(
                x=transformed_word_embeddings[word_idx, 0],
                y=transformed_word_embeddings[word_idx, 1],
                s=40,
                alpha=alpha,
                c=word_color,
            )
            if (scatter_set_rasterized
                    and len(emphasis_entries) >= rasterization_threshold):
                emphasis_scatter_handle.set_rasterized(True)

            # Annotate emphasis word with a text box placed 40 points above
            # the data point (plus the user-supplied offsets).
            offsetbox = TextArea(emphasis_word)
            ab = AnnotationBbox(
                offsetbox,
                tuple(transformed_word_embeddings[word_idx][:2]),
                xybox=(x_offset, 40 + y_offset),
                xycoords="data",
                boxcoords="offset points",
                arrowprops=dict(arrowstyle="->", color="black", linewidth=2),
            )
            ax.add_artist(ab)

    ax.legend()
    if show_plot:
        plt.show()
def drawPointToPoint(self, target_axis: plt.axis, init_axis: plt.axis, px1: np.array, px2: np.array, **kwargs):
    """Mark a point on one axis and connect it to a point on another axis.

    Args:
        target_axis: Axis on which `px1` is marked with a red 'x' and to which
            the connecting patch is added.
        init_axis: Axis whose data coordinates `px2` is expressed in.
        px1: Point (x, y) in data coordinates of `target_axis`.
        px2: Point (x, y) in data coordinates of `init_axis`.
        **kwargs: Extra keyword arguments forwarded to `ConnectionPatch`.
    """
    # Red cross at the start point of the connection.
    start_x = px1[0]
    start_y = px1[1]
    target_axis.plot([start_x], [start_y], 'rx', markersize=10)

    # Both end points are interpreted in the data coordinates of their axis.
    connector = ConnectionPatch(
        xyA=px1,
        xyB=px2,
        coordsA="data",
        coordsB="data",
        axesA=target_axis,
        axesB=init_axis,
        **kwargs,
    )
    target_axis.add_artist(connector)