def __getitem__(self, query: Union[str, List[str]]):
    """
    Retrieve a single embedding or a set of embeddings.

    When the language was fitted manually (`self.fitted_manual`), the already
    fitted count vectorizer and SVD are applied to the query. Otherwise both are
    fitted on the query itself.

    Arguments:
        query: single string or list of strings

    Returns:
        an `Embedding` for a single string, an `EmbeddingSet` for a list

    Raises:
        ValueError: if any of the given strings is empty

    **Usage**

    ```python
    > from whatlies.language import CountVectorLanguage
    > lang = CountVectorLanguage(n_components=2, ngram_range=(1, 2), analyzer="char")
    > lang[['pizza', 'pizzas', 'firehouse', 'firehydrant']]
    ```
    """
    orig_str = isinstance(query, str)
    if orig_str:
        # Wrap the string as a single document; `list(query)` would split it
        # into one document per character.
        query = [query]
    if any(len(q) == 0 for q in query):
        raise ValueError(
            "You've passed an empty string to the language model which is not allowed."
        )
    if self.fitted_manual:
        X = self.cv.transform(query)
        X_vec = self.svd.transform(X)
    else:
        X = self.cv.fit_transform(query)
        X_vec = self.svd.fit_transform(X)
    if orig_str:
        return Embedding(name=query[0], vector=X_vec[0])
    return EmbeddingSet(
        *[Embedding(name=n, vector=v) for n, v in zip(query, X_vec)]
    )
def __getitem__(self, query: Union[str, List[str]]):
    """
    Look up one embedding, or an `EmbeddingSet` for a list of strings.

    A string containing spaces is embedded as the sum of the embeddings of its
    space-separated parts. A single word missing from `self.kv` gets an all-zero
    vector of length `self.kv.vector_size`.

    Arguments:
        query: single string or list of strings

    **Usage**

    ```python
    > from whatlies.language import GensimLanguage
    > lang = GensimLanguage("wordvectors.kv")
    > lang['computer']
    > lang[['computer', 'human', 'dog']]
    ```
    """
    if not isinstance(query, str):
        return EmbeddingSet(*[self[tok] for tok in query])
    if " " in query:
        # Multi-word input: add up the embeddings of the individual parts.
        part_vectors = [self[part].vector for part in query.split(" ")]
        return Embedding(query, np.sum(part_vectors, axis=0))
    try:
        vec = np.sum([self.kv[query]], axis=0)
    except KeyError:
        # Out-of-vocabulary word: fall back to a zero vector.
        vec = np.zeros(self.kv.vector_size)
    return Embedding(query, vec)
def __getitem__(self, query: Union[str, List[str]]):
    """
    Embed one string, or a list of strings, with the count-vectorizer + SVD model.

    If `self.fitted_manual` is set, the fitted `self.cv` and `self.svd` only
    transform the input. Otherwise both are fitted on the given strings first.

    Arguments:
        query: single string or list of strings

    Returns:
        an `Embedding` for a string, an `EmbeddingSet` for a list

    Raises:
        ValueError: when one of the strings is empty

    **Usage**

    ```python
    from whatlies.language import CountVectorLanguage
    lang = CountVectorLanguage(n_components=2, ngram_range=(1, 2), analyzer="char")
    lang[['pizza', 'pizzas', 'firehouse', 'firehydrant']]
    ```
    """
    single = isinstance(query, str)
    docs = [query] if single else query
    if 0 in [len(doc) for doc in docs]:
        raise ValueError(
            "You've passed an empty string to the language model which is not allowed."
        )
    if self.fitted_manual:
        counts = self.cv.transform(docs)
        vectors = self.svd.transform(counts)
    else:
        counts = self.cv.fit_transform(docs)
        vectors = self.svd.fit_transform(counts)
    if single:
        return Embedding(name=docs[0], vector=vectors[0])
    members = [Embedding(name=n, vector=v) for n, v in zip(docs, vectors)]
    return EmbeddingSet(*members)
def _get_embedding(self, query: str) -> Embedding:
    """
    Embed `query`, restricted to a bracketed context span when one is present.

    If `self._check_query_format(query)` is falsy, the vector of the whole
    model output is used. Otherwise the brackets are stripped and the span
    `[start_idx:end_idx]` reported by `self._get_context_pos(query)` is embedded.
    The returned `Embedding` keeps the original query (with brackets) as its name.
    """
    if not self._check_query_format(query):
        return Embedding(query, self.model(query).vector)
    start_idx, end_idx = self._get_context_pos(query)
    bracketless = query.replace("[", "").replace("]", "")
    span = self.model(bracketless)[start_idx:end_idx]
    return Embedding(query, span.vector)
def from_names_X(cls, names, X):
    """
    Constructs an `EmbeddingSet` instance from the given embedding names and vectors.

    Arguments:
        names: an iterable containing the names of embeddings
        X: an iterable of 1D vectors, or a 2D numpy array; it should have the same length as `names`

    Returns:
        an instance of `cls` holding one `Embedding` per name; if `names` contains
        duplicates, only the last vector given for each name is kept

    Raises:
        ValueError: if the number of names differs from the number of vectors

    Usage:

    ```python
    from whatlies.embeddingset import EmbeddingSet

    names = ["foo", "bar", "buz"]
    vecs = [
        [0.1, 0.3],
        [0.7, 0.2],
        [0.1, 0.9],
    ]

    emb = EmbeddingSet.from_names_X(names, vecs)
    ```
    """
    X = np.array(X)
    # Materialize `names` so any iterable (e.g. a generator) supports `len`
    # and can still be zipped after the length check.
    names = list(names)
    if len(X) != len(names):
        raise ValueError(
            f"The number of given names ({len(names)}) and vectors ({len(X)}) should be the same."
        )
    return cls({n: Embedding(n, v) for n, v in zip(names, X)})
def __getitem__(self, item):
    """
    Retrieve a single embedding or a set of embeddings.

    For a string, a message is run through every component of `self.pipeline`.
    The sentence encoding of the whole utterance is then read from the
    diagnostic data of the first entry whose key contains `"DIET"`.

    Arguments:
        item: single string or list of strings

    Returns:
        an `Embedding` for a string, an `EmbeddingSet` for a list of strings

    Raises:
        ValueError: if `item` is neither a string nor a list, or if the
            diagnostic data has no entry whose key contains `"DIET"`

    **Usage**

    ```python
    from whatlies.language import DIETLanguage
    lang = DIETLanguage("path/to/model.tar.gz")
    lang[['hi', 'hello', 'greetings']]
    ```
    """
    if isinstance(item, str):
        # RuntimeWarnings emitted while processing are silenced for this block only.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            msg = Message({"text": item})
            for p in self.pipeline:
                p.process(msg)
            diagnostic_data = msg.as_dict_nlu()["diagnostic_data"]
            key_of_interest = next(
                (k for k in diagnostic_data.keys() if "DIET" in k), None
            )
            if key_of_interest is None:
                # Previously an uninformative IndexError from `[...][0]`.
                raise ValueError(
                    "No DIET component found in the diagnostic data of the pipeline; "
                    f"available keys: {list(diagnostic_data.keys())}."
                )
            # It's assumed that the final token in the array here represents the __CLS__ token.
            # These are also known as the "sentence embeddings"
            tensors = diagnostic_data[key_of_interest]["text_transformed"]
            return Embedding(item, tensors[-1][-1])
    if isinstance(item, list):
        return EmbeddingSet(*[self[i] for i in item])
    raise ValueError(f"Item must be list of strings got {item}.")
def plot(
    self,
    kind: str = "arrow",
    x_axis: Union[int, str, Embedding] = None,
    y_axis: Union[int, str, Embedding] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    color: str = None,
    show_ops: bool = False,
    annot: bool = True,
    axis_option: Optional[str] = None,
):
    """
    Draw the set with matplotlib; `plot_interactive` is usually the nicer choice.

    Each embedding is reduced to a 2D point, one coordinate per axis, through
    `_get_plot_axis_value_and_label`. The resulting points are drawn by
    `handle_2d_plot`.

    Arguments:
        kind: the plot type, one of `scatter`, `arrow` or `text`
        x_axis: what to use as x-axis, required when dim > 2; an integer selects that
            dimension of the embeddings, a string is looked up in this set
        y_axis: same as `x_axis`, for the y-axis
        x_label: label for the x-axis; derived from `x_axis` when not given
        y_label: label for the y-axis; derived from `y_axis` when not given
        title: an optional title for the plot
        color: the color of the dots
        show_ops: also show the applied operations (only used by `text`)
        annot: whether the points get annotated
        axis_option: forwarded as `axis_option` to `handle_2d_plot`, meant as the
            `option` argument of `matplotlib.pyplot.axis` (e.g. `'equal'` keeps circles
            circular, which helps preserve geometric relationships such as orthogonality).
            See `matplotlib.pyplot.axis`
            [documentation](https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.axis.html#matplotlib-pyplot-axis)
            for the accepted values.

    Returns:
        `self`, so calls can be chained
    """
    # Names refer to members of this set; resolve them before projecting.
    x_ref = self[x_axis] if isinstance(x_axis, str) else x_axis
    y_ref = self[y_axis] if isinstance(y_axis, str) else y_axis

    points = []
    for emb in self.embeddings.values():
        x_val, x_lab = emb._get_plot_axis_value_and_label(x_ref, dir="x")
        y_val, y_lab = emb._get_plot_axis_value_and_label(y_ref, dir="y")
        points.append(Embedding(name=emb.name, vector=[x_val, y_val], orig=emb.orig))

    # Default labels come from the last projected point.
    if x_label is None:
        x_label = x_lab
    if y_label is None:
        y_label = y_lab

    handle_2d_plot(
        points,
        kind=kind,
        color=color,
        xlabel=x_label,
        ylabel=y_label,
        title=title,
        show_operations=show_ops,
        annot=annot,
        axis_option=axis_option,
    )
    return self
def _get_embedding(self, query: str):
    """
    Embed `query` as the sum of its per-token feature vectors, skipping special tokens.

    The model output for the query (without padding) provides the feature rows. The
    tokenizer's `special_tokens_mask` selects which rows are dropped.
    NOTE(review): this assumes the feature rows line up one-to-one with the
    tokenizer's mask.
    """
    token_features = np.array(self.model(query, padding=False)[0])
    encoded = self.model.tokenizer(
        query, return_special_tokens_mask=True, return_tensors="np"
    )
    keep = np.logical_not(encoded["special_tokens_mask"][0])
    return Embedding(query, token_features[keep].sum(axis=0))
def __getitem__(
        self, query: Union[str, List[str]]) -> Union[Embedding, EmbeddingSet]:
    """
    Fetch one embedding for a string, or an `EmbeddingSet` for a list of strings.

    A string is encoded as a batch of one. When `self.signature` is
    `"encode_sequence"`, the `"sequence_encoding"` output is summed over axis 1
    (presumably the token axis). Otherwise the `"default"` output is used.

    Arguments:
        query: single string or list of strings

    **Usage**

    ```python
    > from whatlies.language import ConveRTLanguage
    > lang = ConveRTLanguage()
    > lang['bank']
    > lang[['bank of the river', 'money on the bank', 'bank']]
    ```
    """
    if not isinstance(query, str):
        return EmbeddingSet(*[self[tok] for tok in query])
    outputs = self.model(tf.convert_to_tensor([query]))
    if self.signature == "encode_sequence":
        summed = outputs["sequence_encoding"].numpy().sum(axis=1)
        vec = summed[0]
    else:
        vec = outputs["default"].numpy()[0]
    return Embedding(query, vec)
def __getitem__(self, string):
    """
    Return an `Embedding` for a string, or an `EmbeddingSet` for a list of strings.

    A string is validated with `self._input_str_legal`. Its span is computed by
    `_selected_idx_spacy`, and the vector of that span of the bracket-free text
    (processed by `self.nlp`) is used.
    """
    if not isinstance(string, str):
        return EmbeddingSet(*[self[tok] for tok in string])
    self._input_str_legal(string)
    start, end = _selected_idx_spacy(string)
    doc = self.nlp(string.replace("[", "").replace("]", ""))
    return Embedding(string, doc[start:end].vector)
def __getitem__(self, string):
    """
    Retrieve the embedding of `string`, optionally restricted to a bracketed span.

    Words are found by splitting `string` on single spaces. If a word starts
    with `[` and the same or a later word ends with `]` (e.g.
    `"the [bank] of the river"`), the vector of that span of the processed text
    is returned, so the rest of the sentence only acts as context. Otherwise
    the vector of the whole processed text is returned.

    Arguments:
        string: the text to embed

    Returns:
        an `Embedding` named after `string` (brackets included)
    """
    # Strip the brackets before calling the model: left in, they can end up as
    # tokens of their own and shift the positions computed from the words below.
    doc = self.nlp(string.replace("[", "").replace("]", ""))
    start, end = None, None
    for idx, word in enumerate(string.split(" ")):
        # Repeated spaces produce empty words; `word[0]` would raise IndexError.
        if not word:
            continue
        if word[0] == "[":
            start = idx
        if word[-1] == "]":
            end = idx + 1
    # `None` sentinels (instead of 0 / -1) let a bracket on the first word count;
    # a reversed pair would give an empty span, so fall back to the whole text.
    if start is not None and end is not None and start < end:
        return Embedding(string, doc[start:end].vector)
    return Embedding(string, doc.vector)
def __getitem__(self, query):
    """
    Retrieve a single embedding or a set of embeddings.

    Keys are looked up directly in `self.s2v`, so they use its key format,
    e.g. `"duck|NOUN"`.

    Arguments:
        query: single string or list of strings

    Returns:
        an `Embedding` for a string, an `EmbeddingSet` for a list of strings

    Raises:
        KeyError: if `self.s2v` has no vector for a key

    **Usage**

    ```python
    > from whatlies.language import Sense2VecLanguage
    > lang = Sense2VecLanguage("path/to/s2v")
    > lang['duck|NOUN']
    > lang[['duck|NOUN', 'duck|VERB']]
    ```
    """
    if isinstance(query, str):
        vec = self.s2v[query]
        # sense2vec returns `None` (rather than raising) for unknown keys; fail
        # loudly instead of building an `Embedding` without a vector.
        if vec is None:
            raise KeyError(f"No vector found for {query!r}.")
        return Embedding(query, vec)
    return EmbeddingSet(*[self[tok] for tok in query])
def __getitem__(self, query: Union[str, List[str]]):
    """
    Look up one embedding, or an `EmbeddingSet` for a list of strings.

    Each string is passed to `self.model.get_word_vector`.

    Arguments:
        query: single string or list of strings

    **Usage**

    ```python
    > lang = FasttextLanguage("cc.en.300.bin")
    > lang['python']
    > lang[['python', 'snake']]
    > lang[['nobody expects', 'the spanish inquisition']]
    ```
    """
    if not isinstance(query, str):
        return EmbeddingSet(*[self[tok] for tok in query])
    return Embedding(query, self.model.get_word_vector(query))
def __getitem__(self, query: Union[str, List[str]]):
    """
    Look up one embedding, or an `EmbeddingSet` for a list of strings.

    Each string is first validated with `self._input_str_legal` and then
    embedded with `self.ft.get_word_vector`.

    Arguments:
        query: single string or list of strings

    **Usage**

    ```python
    > lang = FasttextLanguage("cc.en.300.bin")
    > lang['python']
    > lang[['python', 'snake']]
    > lang[['nobody expects', 'the spanish inquisition']]
    ```
    """
    if not isinstance(query, str):
        return EmbeddingSet(*[self[tok] for tok in query])
    self._input_str_legal(query)
    return Embedding(query, self.ft.get_word_vector(query))
def average(self, name=None):
    """
    Takes the average over all the embedding vectors in the embeddingset. Turns it into
    a new `Embedding`.

    Arguments:
        name: manually specify the name of the average embedding; when not given
            (or falsy) it defaults to `f"{self.name}.average()"`

    Returns:
        an `Embedding` whose vector is the element-wise mean of all vectors

    Raises:
        ValueError: if the set contains no embeddings

    ```python
    from whatlies.embedding import Embedding
    from whatlies.embeddingset import EmbeddingSet

    foo = Embedding("foo", [1.0, 0.0])
    bar = Embedding("bar", [0.0, 1.0])
    emb = EmbeddingSet(foo, bar)

    emb.average().vector                   # [0.5, 0.5]
    emb.average(name="the-average").vector # [0.5, 0.5]
    ```
    """
    # np.mean over an empty array only warns and yields NaN; reject it explicitly.
    if not self.embeddings:
        raise ValueError("Cannot take the average of an empty EmbeddingSet.")
    name = f"{self.name}.average()" if not name else name
    x = np.array([v.vector for v in self.embeddings.values()])
    return Embedding(name, np.mean(x, axis=0))
def __getitem__(self, query: Union[str, List[str]]):
    """
    Look up one embedding, or an `EmbeddingSet` for a list of strings.

    Strings may contain several tokens and may use the Bert-style bracket DSL
    to select a span; see https://rasahq.github.io/whatlies/tutorial/languages/#bert-style.
    A string is validated with `self._input_str_legal`. The span from
    `_selected_idx_spacy` is then embedded from the bracket-free text
    processed by `self.nlp`.

    Arguments:
        query: single string or list of strings

    **Usage**

    ```python
    > lang = SpacyLanguage("en_core_web_md")
    > lang['python']
    > lang[['python', 'snake']]
    > lang[['nobody expects', 'the spanish inquisition']]
    ```
    """
    if not isinstance(query, str):
        return EmbeddingSet(*[self[tok] for tok in query])
    self._input_str_legal(query)
    start, end = _selected_idx_spacy(query)
    doc = self.nlp(query.replace("[", "").replace("]", ""))
    return Embedding(query, doc[start:end].vector)
def __getitem__(self, string):
    """
    Retrieve a single embedding or a set of embeddings from `self.s2v`.

    Arguments:
        string: a single key (e.g. `"duck|NOUN"`) or a list of keys

    Returns:
        an `Embedding` for a string, an `EmbeddingSet` for a list of strings

    Raises:
        KeyError: if `self.s2v` has no vector for a key
    """
    if isinstance(string, str):
        vec = self.s2v[string]
        # sense2vec returns `None` (rather than raising) for unknown keys; fail
        # loudly instead of building an `Embedding` without a vector.
        if vec is None:
            raise KeyError(f"No vector found for {string!r}.")
        return Embedding(string, vec)
    return EmbeddingSet(*[self[tok] for tok in string])
def plot_interactive_matrix(
    self,
    *axes: Union[int, str, Embedding],
    axes_metric: Optional[Union[str, Callable, Sequence]] = None,
    annot: bool = True,
    width: int = 200,
    height: int = 200,
):
    """
    Build an interactive scatter-plot matrix (Altair) of the set of embeddings.

    Every pair of the given axes gets its own panel.

    Arguments:
        axes: the axes to plot against each other (default: `0, 1`); each one is an
            integer (a dimension of the embeddings), the name of an embedding in this
            set, or an `Embedding` instance.
        axes_metric: how embeddings are projected onto axes given as a string or an
            `Embedding`: one of `'cosine_similarity'`, `'cosine_distance'`,
            `'euclidean'`, or a callable mapping two vectors to a scalar. A list or
            tuple with one entry per axis sets a different metric for each axis.
            By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
        annot: whether the drawn points are annotated
        width: width of each repeated chart
        height: height of each repeated chart

    Raises:
        ValueError: if `axes_metric` is a list/tuple whose length differs from the
            number of axes.

    **Usage**

    ```python
    from whatlies.language import SpacyLanguage
    from whatlies.transformers import Pca

    words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
             "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
             "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
             "person", "family", "brother", "sister"]

    lang = SpacyLanguage("en_core_web_sm")
    emb = lang[words]

    emb.transform(Pca(3)).plot_interactive_matrix(0, 1, 2)
    ```
    """
    if not axes:
        axes = [0, 1]

    if isinstance(axes_metric, (list, tuple)):
        if len(axes_metric) != len(axes):
            raise ValueError(
                f"The number of given axes metrics should be the same as the number of given axes. Got {len(axes)} axes vs. {len(axes_metric)} metrics."
            )
        metrics = axes_metric
    else:
        metrics = [axes_metric] * len(axes)

    # One dataframe column per axis, keyed by that axis' label.
    matrix = self.to_X()
    columns = {}
    for axis, metric in zip(axes, metrics):
        if isinstance(axis, int):
            columns["Dimension " + str(axis)] = matrix[:, axis]
            continue
        target = self[axis] if isinstance(axis, str) else axis
        mapping = Embedding._get_plot_axis_metric_callable(metric)
        columns[target.name] = self.compare_against(target, mapping=mapping)

    members = list(self.embeddings.values())
    plot_df = pd.DataFrame(columns)
    plot_df["name"] = [member.name for member in members]
    plot_df["original"] = [member.orig for member in members]
    column_names = list(columns)

    chart = (
        alt.Chart(plot_df)
        .mark_circle()
        .encode(
            x=alt.X(alt.repeat("column"), type="quantitative"),
            y=alt.Y(alt.repeat("row"), type="quantitative"),
            tooltip=["name", "original"],
        )
    )
    if annot:
        labels = chart.mark_text(dx=-15, dy=3, color="black").encode(text="original")
        chart = chart + labels

    return (
        chart.properties(width=width, height=height)
        .repeat(row=column_names[::-1], column=column_names)
        .interactive()
    )
def _get_embedding(self, query: str) -> Embedding:
    """Embed `query` with the `.vector` attribute of `self.model(query)`."""
    processed = self.model(query)
    return Embedding(query, processed.vector)
def plot_3d(
    self,
    x_axis: Union[int, str, Embedding] = 0,
    y_axis: Union[int, str, Embedding] = 1,
    z_axis: Union[int, str, Embedding] = 2,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    z_label: Optional[str] = None,
    title: Optional[str] = None,
    color: str = None,
    axis_metric: Optional[Union[str, Callable, Sequence]] = None,
    annot: bool = True,
):
    """
    Creates a 3d visualisation of the embedding.

    Arguments:
        x_axis: the x-axis to be used, must be given when dim > 3; if an integer, the corresponding
            dimension of embedding is used.
        y_axis: the y-axis to be used, must be given when dim > 3; if an integer, the corresponding
            dimension of embedding is used.
        z_axis: the z-axis to be used, must be given when dim > 3; if an integer, the corresponding
            dimension of embedding is used.
        x_label: an optional label used for x-axis; if not given, it is set based on value of `x_axis`.
        y_label: an optional label used for y-axis; if not given, it is set based on value of `y_axis`.
        z_label: an optional label used for z-axis; if not given, it is set based on value of `z_axis`.
        title: an optional title for the plot.
        color: the property to use for the color; float values are used as-is, any other
            values are mapped to integer codes in order of first appearance.
        axis_metric: the metric used to project each embedding on the axes; only used when the corresponding
            axis is a string or an `Embedding` instance. It could be a string (`'cosine_similarity'`,
            `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input and
            returns a scalar value as output. To set different metrics for the three axes,
            pass a list/tuple of exactly three metrics. By default (`None`), normalized scalar
            projection (i.e. `>` operator) is used.
        annot: drawn points should be annotated

    Returns:
        the matplotlib 3d axes the points were drawn on

    Raises:
        ValueError: if `axis_metric` is a list/tuple whose length is not three.

    **Usage**

    ```python
    from whatlies.language import SpacyLanguage
    from whatlies.transformers import Pca

    words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
             "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
             "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
             "person", "family", "brother", "sister"]

    lang = SpacyLanguage("en_core_web_sm")
    emb = lang[words]

    emb.transform(Pca(3)).plot_3d(annot=True)
    emb.transform(Pca(3)).plot_3d("king", "dog", "red")
    emb.transform(Pca(3)).plot_3d("king", "dog", "red", axis_metric="cosine_distance")
    ```
    """
    if isinstance(x_axis, str):
        x_axis = self[x_axis]
    if isinstance(y_axis, str):
        y_axis = self[y_axis]
    if isinstance(z_axis, str):
        z_axis = self[z_axis]

    if isinstance(axis_metric, (list, tuple)):
        # Previously a short sequence raised IndexError and extra entries were ignored.
        if len(axis_metric) != 3:
            raise ValueError(
                f"`axis_metric` should contain one metric per axis (3), got {len(axis_metric)}."
            )
        x_axis_metric, y_axis_metric, z_axis_metric = axis_metric
    else:
        x_axis_metric = y_axis_metric = z_axis_metric = axis_metric

    # Build the embedding matrix at most once, and only if an integer axis needs it.
    X = (
        self.to_X()
        if any(isinstance(a, int) for a in (x_axis, y_axis, z_axis))
        else None
    )

    def _axis_values_and_label(axis, metric):
        # Integer axes select a raw dimension; any other axis is an embedding
        # that every point is compared against with the chosen metric.
        if isinstance(axis, int):
            return X[:, axis], "Dimension " + str(axis)
        mapping = Embedding._get_plot_axis_metric_callable(metric)
        return self.compare_against(axis, mapping=mapping), axis.name

    x_val, x_lab = _axis_values_and_label(x_axis, x_axis_metric)
    y_val, y_lab = _axis_values_and_label(y_axis, y_axis_metric)
    z_val, z_lab = _axis_values_and_label(z_axis, z_axis_metric)
    x_lab = x_label if x_label is not None else x_lab
    y_lab = y_label if y_label is not None else y_lab
    z_lab = z_label if z_label is not None else z_lab

    # Save relevant information in a dataframe for plotting later.
    plot_df = pd.DataFrame(
        {
            "x_axis": x_val,
            "y_axis": y_val,
            "z_axis": z_val,
            "name": [v.name for v in self.embeddings.values()],
            "original": [v.orig for v in self.embeddings.values()],
        }
    )

    # Deal with the colors of the dots.
    if color:
        plot_df["color"] = [
            getattr(v, color) if hasattr(v, color) else ""
            for v in self.embeddings.values()
        ]
        # `dict.fromkeys` keeps first-appearance order. Iterating a `set` of strings
        # depends on hash randomization, so codes could differ between runs.
        categories = dict.fromkeys(plot_df["color"])
        color_map = {k: code for code, k in enumerate(categories)}
        color_val = [
            k if isinstance(k, float) else color_map[k] for k in plot_df["color"]
        ]
    else:
        color_val = None

    ax = plt.axes(projection="3d")
    ax.scatter3D(
        plot_df["x_axis"], plot_df["y_axis"], plot_df["z_axis"], c=color_val, s=25
    )

    # Set the labels, titles, text annotations.
    ax.set_xlabel(x_lab)
    ax.set_ylabel(y_lab)
    ax.set_zlabel(z_lab)
    if annot:
        for _, row in plot_df.iterrows():
            ax.text(row["x_axis"], row["y_axis"], row["z_axis"] + 0.05, row["original"])
    if title:
        ax.set_title(label=title)
    return ax
def plot(
    self,
    kind: str = "arrow",
    x_axis: Union[int, str, Embedding] = 0,
    y_axis: Union[int, str, Embedding] = 1,
    axis_metric: Optional[Union[str, Callable, Sequence]] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    color: str = None,
    show_ops: bool = False,
    annot: bool = True,
    axis_option: Optional[str] = None,
):
    """
    Draw the set with matplotlib; `plot_interactive` is usually the nicer choice.

    Each embedding is reduced to a 2D point, one coordinate per axis, through
    `_get_plot_axis_value_and_label` using the metric for that axis. The points
    are drawn by `handle_2d_plot`.

    Arguments:
        kind: the plot type, one of `scatter`, `arrow` or `text`
        x_axis: what to use as x-axis, required when dim > 2; an integer selects that
            dimension of the embeddings, a string is looked up in this set
        y_axis: same as `x_axis`, for the y-axis
        axis_metric: how embeddings are projected onto an axis given as a string or
            `Embedding`: one of `'cosine_similarity'`, `'cosine_distance'`,
            `'euclidean'`, or a callable mapping two vectors to a scalar. A list or
            tuple supplies the x-metric (first item) and the y-metric (second item).
            By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
        x_label: label for the x-axis; derived from `x_axis` when not given
        y_label: label for the y-axis; derived from `y_axis` when not given
        title: an optional title for the plot
        color: the color of the dots
        show_ops: also show the applied operations (only used by `text`)
        annot: whether the points get annotated
        axis_option: forwarded as `axis_option` to `handle_2d_plot`, meant as the
            `option` argument of `matplotlib.pyplot.axis` (e.g. `'equal'` keeps circles
            circular, which helps preserve geometric relationships such as orthogonality).
            See `matplotlib.pyplot.axis`
            [documentation](https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.axis.html#matplotlib-pyplot-axis)
            for the accepted values.

    Returns:
        `self`, so calls can be chained
    """
    # Names refer to members of this set; resolve them before projecting.
    x_ref = self[x_axis] if isinstance(x_axis, str) else x_axis
    y_ref = self[y_axis] if isinstance(y_axis, str) else y_axis

    if isinstance(axis_metric, (list, tuple)):
        metric_x, metric_y = axis_metric[0], axis_metric[1]
    else:
        metric_x = metric_y = axis_metric

    points = []
    for emb in self.embeddings.values():
        x_val, x_lab = emb._get_plot_axis_value_and_label(x_ref, metric_x, dir="x")
        y_val, y_lab = emb._get_plot_axis_value_and_label(y_ref, metric_y, dir="y")
        points.append(Embedding(name=emb.name, vector=[x_val, y_val], orig=emb.orig))

    # Default labels come from the last projected point.
    if x_label is None:
        x_label = x_lab
    if y_label is None:
        y_label = y_lab

    handle_2d_plot(
        points,
        kind=kind,
        color=color,
        xlabel=x_label,
        ylabel=y_label,
        title=title,
        show_operations=show_ops,
        annot=annot,
        axis_option=axis_option,
    )
    return self
def _get_embedding(self, query: str) -> Embedding:
    """Embed `query` by calling `self.model` on a one-element batch and taking row 0."""
    batch_output = self.model([query])
    return Embedding(query, batch_output.numpy()[0])
def plot_interactive(
    self,
    x_axis: Union[int, str, Embedding] = 0,
    y_axis: Union[int, str, Embedding] = 1,
    axis_metric: Optional[Union[str, Callable, Sequence]] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    annot: bool = True,
    color: Union[None, str] = None,
):
    """
    Makes highly interactive plot of the set of embeddings.

    Arguments:
        x_axis: the x-axis to be used, must be given when dim > 2; if an integer, the corresponding
            dimension of embedding is used.
        y_axis: the y-axis to be used, must be given when dim > 2; if an integer, the corresponding
            dimension of embedding is used.
        axis_metric: the metric used to project each embedding on the axes; only used when the corresponding
            axis (i.e. `x_axis` or `y_axis`) is a string or an `Embedding` instance. It could be a string
            (`'cosine_similarity'`, `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input
            and returns a scalar value as output. To set different metrics for x- and y-axis, a list or a tuple of
            two elements could be given. By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
        x_label: an optional label used for x-axis; if not given, it is set based on `x_axis` value.
        y_label: an optional label used for y-axis; if not given, it is set based on `y_axis` value.
        title: an optional title for the plot; if not given, it is set based on `x_axis` and `y_axis` values.
        annot: drawn points should be annotated
        color: a property of the embeddings used to color the points

    Returns:
        an Altair chart (points, plus text labels when `annot` is set)

    **Usage**

    ```python
    from whatlies.language import SpacyLanguage
    words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
             "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
             "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
             "person", "family", "brother", "sister"]

    lang = SpacyLanguage("en_core_web_sm")
    emb = lang[words]

    emb.plot_interactive('man', 'woman')
    ```
    """
    if isinstance(x_axis, str):
        x_axis = self[x_axis]
    if isinstance(y_axis, str):
        y_axis = self[y_axis]

    if isinstance(axis_metric, (list, tuple)):
        x_axis_metric = axis_metric[0]
        y_axis_metric = axis_metric[1]
    else:
        x_axis_metric = y_axis_metric = axis_metric

    # Build the embedding matrix at most once, and only if an integer axis needs it.
    X = self.to_X() if isinstance(x_axis, int) or isinstance(y_axis, int) else None

    def _axis_values_and_label(axis, metric):
        # Integer axes select a raw dimension; any other axis is an embedding
        # that every point is compared against with the chosen metric.
        if isinstance(axis, int):
            return X[:, axis], "Dimension " + str(axis)
        mapping = Embedding._get_plot_axis_metric_callable(metric)
        return self.compare_against(axis, mapping=mapping), axis.name

    # Determine axes values and labels
    x_val, x_lab = _axis_values_and_label(x_axis, x_axis_metric)
    y_val, y_lab = _axis_values_and_label(y_axis, y_axis_metric)

    x_label = x_label if x_label is not None else x_lab
    y_label = y_label if y_label is not None else y_lab
    title = title if title is not None else f"{x_lab} vs. {y_lab}"

    plot_df = pd.DataFrame(
        {
            "x_axis": x_val,
            "y_axis": y_val,
            "name": [v.name for v in self.embeddings.values()],
            "original": [v.orig for v in self.embeddings.values()],
        }
    )

    if color:
        plot_df[color] = [
            getattr(v, color) if hasattr(v, color) else ""
            for v in self.embeddings.values()
        ]

    result = (
        alt.Chart(plot_df)
        .mark_circle(size=60)
        .encode(
            x=alt.X("x_axis", axis=alt.Axis(title=x_label)),
            # The y channel takes an `alt.Y` definition, not `alt.X`.
            y=alt.Y("y_axis", axis=alt.Axis(title=y_label)),
            tooltip=["name", "original"],
            color=alt.Color(":N", legend=None) if not color else alt.Color(color),
        )
        .properties(title=title)
        .interactive()
    )

    if annot:
        text = (
            alt.Chart(plot_df)
            .mark_text(dx=-15, dy=3, color="black")
            .encode(
                x="x_axis",
                y="y_axis",
                text="original",
            )
        )
        result = result + text
    return result