コード例 #1
0
ファイル: countvector_lang.py プロジェクト: mkaze/whatlies
    def __getitem__(self, query: Union[str, List[str]]):
        """
        Retrieve a single embedding or a set of embeddings.

        A single string is treated as one document and a list of strings as a list of
        documents. When the language was not fitted manually (`self.fitted_manual` is
        falsy) the count vectorizer and the SVD are fitted on the query itself before
        transforming it.

        Arguments:
            query: single string or list of strings

        Returns:
            an `Embedding` when `query` is a string, otherwise an `EmbeddingSet`

        **Usage**

        ```python
        > from whatlies.language import CountVectorLanguage
        > lang = CountVectorLanguage(n_components=2, ngram_range=(1, 2), analyzer="char")
        > lang[['pizza', 'pizzas', 'firehouse', 'firehydrant']]
        ```
        """
        orig_str = isinstance(query, str)
        if orig_str:
            # Wrap the string as a single document. `list(query)` would split it into
            # characters and embed each character separately instead of the string.
            query = [query]
        if self.fitted_manual:
            X = self.cv.transform(query)
            X_vec = self.svd.transform(X)
        else:
            X = self.cv.fit_transform(query)
            X_vec = self.svd.fit_transform(X)
        if orig_str:
            return Embedding(name=query[0], vector=X_vec[0])
        return EmbeddingSet(
            *[Embedding(name=n, vector=v) for n, v in zip(query, X_vec)])
コード例 #2
0
ファイル: _gensim_lang.py プロジェクト: nicoTrombon/whatlies
    def __getitem__(self, query: Union[str, List[str]]):
        """
        Retrieve a single embedding or a set of embeddings from the keyed vectors.

        A string containing spaces is split on single spaces and the vectors of its
        parts are summed. A single word that is missing from `self.kv` maps to a zero
        vector of length `self.kv.vector_size` instead of raising.

        Arguments:
            query: single string or list of strings

        **Usage**
        ```python
        > from whatlies.language import GensimLanguage
        > lang = GensimLanguage("wordvectors.kv")
        > lang['computer']
        > lang = GensimLanguage("wordvectors.kv")
        > lang[['computer', 'human', 'dog']]
        ```
        """
        if not isinstance(query, str):
            return EmbeddingSet(*[self[token] for token in query])

        if " " in query:
            # Multi-word query: embed each part recursively and add the vectors up.
            part_vectors = [self[part].vector for part in query.split(" ")]
            return Embedding(query, np.sum(part_vectors, axis=0))

        try:
            vector = np.sum([self.kv[query]], axis=0)
        except KeyError:
            # Out-of-vocabulary word: fall back to an all-zero vector.
            vector = np.zeros(self.kv.vector_size)
        return Embedding(query, vector)
コード例 #3
0
    def __getitem__(self, query: Union[str, List[str]]):
        """
        Retrieve a single embedding or a set of embeddings.

        Each string is treated as one document. If the language was not fitted
        manually (`self.fitted_manual` is falsy), the count vectorizer and SVD are fitted
        on the given documents before they are transformed.

        Arguments:
            query: single string or list of strings

        Raises:
            ValueError: when any of the given strings is empty.

        **Usage**

        ```python
        from whatlies.language import CountVectorLanguage
        lang = CountVectorLanguage(n_components=2, ngram_range=(1, 2), analyzer="char")
        lang[['pizza', 'pizzas', 'firehouse', 'firehydrant']]
        ```
        """
        single = isinstance(query, str)
        docs = [query] if single else query

        empty_docs = [doc for doc in docs if len(doc) == 0]
        if empty_docs:
            raise ValueError(
                "You've passed an empty string to the language model which is not allowed."
            )

        if not self.fitted_manual:
            reduced = self.svd.fit_transform(self.cv.fit_transform(docs))
        else:
            reduced = self.svd.transform(self.cv.transform(docs))

        if single:
            return Embedding(name=docs[0], vector=reduced[0])
        pairs = zip(docs, reduced)
        return EmbeddingSet(*[Embedding(name=doc, vector=vec) for doc, vec in pairs])
コード例 #4
0
ファイル: _spacy_lang.py プロジェクト: timvink/whatlies
 def _get_embedding(self, query: str) -> Embedding:
     """
     Embed `query` with the spaCy model, honouring the bracket selection syntax.

     Without brackets the vector of the whole processed query is returned. With
     brackets, the brackets are stripped, the query is processed, and the vector of
     the token span given by `self._get_context_pos(query)` is returned. The resulting
     `Embedding` is always named after the original `query`.
     """
     if not self._check_query_format(query):
         return Embedding(query, self.model(query).vector)

     start_idx, end_idx = self._get_context_pos(query)
     stripped = query.replace("[", "").replace("]", "")
     span = self.model(stripped)[start_idx:end_idx]
     return Embedding(query, span.vector)
コード例 #5
0
    def from_names_X(cls, names, X):
        """
        Constructs an `EmbeddingSet` instance from the given embedding names and vectors.

        Arguments:
            names: an iterable containing the names of embeddings
            X: an iterable of 1D vectors, or a 2D numpy array; it should have the same length as `names`

        Raises:
            ValueError: when `names` and `X` do not have the same length.

        Note:
            embeddings are keyed by name, so if `names` contains duplicates the later
            entries overwrite the earlier ones.

        Usage:

        ```python
        from whatlies.embeddingset import EmbeddingSet

        names = ["foo", "bar", "buz"]
        vecs = [
            [0.1, 0.3],
            [0.7, 0.2],
            [0.1, 0.9],
        ]

        emb = EmbeddingSet.from_names_X(names, vecs)
        ```
        """
        # Materialize `names` so any iterable (e.g. a generator) supports `len()` below.
        names = list(names)
        X = np.array(X)
        if len(X) != len(names):
            raise ValueError(
                f"The number of given names ({len(names)}) and vectors ({len(X)}) should be the same."
            )
        return cls({n: Embedding(n, v) for n, v in zip(names, X)})
コード例 #6
0
ファイル: _diet_lang.py プロジェクト: RasaHQ/whatlies
    def __getitem__(self, item):
        """
        Retrieve a single embedding or a set of embeddings from the DIET model.

        For a string, the message is run through every component of `self.pipeline`.
        The returned vector is the last row of the `text_transformed` diagnostic data
        of the first diagnostic key containing "DIET". That row is assumed to be the
        __CLS__ token, i.e. the sentence embedding of the whole utterance.

        Arguments:
            item: single string or list of strings

        Raises:
            ValueError: when `item` is neither a string nor a list.

        **Usage**
        ```python
        from whatlies.language import DIETLanguage

        lang = DIETLanguage("path/to/model.tar.gz")
        lang[['hi', 'hello', 'greetings']]
        ```
        """
        if isinstance(item, list):
            return EmbeddingSet(*[self[i] for i in item])
        if not isinstance(item, str):
            raise ValueError(f"Item must be list of strings got {item}.")

        with warnings.catch_warnings():
            # Suppress RuntimeWarnings emitted while the pipeline processes the text.
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            message = Message({"text": item})
            for component in self.pipeline:
                component.process(message)
            diagnostics = message.as_dict_nlu()["diagnostic_data"]
            diet_keys = [key for key in diagnostics.keys() if "DIET" in key]
            transformed = diagnostics[diet_keys[0]]["text_transformed"]
            # The final token is assumed to be __CLS__ (the "sentence embedding").
            return Embedding(item, transformed[-1][-1])
コード例 #7
0
    def plot(
        self,
        kind: str = "arrow",
        x_axis: Union[int, str, Embedding] = None,
        y_axis: Union[int, str, Embedding] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        title: Optional[str] = None,
        color: str = None,
        show_ops: bool = False,
        annot: bool = True,
        axis_option: Optional[str] = None,
    ):
        """
        Makes (perhaps inferior) matplotlib plot. Consider using `plot_interactive` instead.

        Every embedding is projected onto the two requested axes and the resulting 2D
        points are handed to `handle_2d_plot`.

        Arguments:
            kind: what kind of plot to make, can be `scatter`, `arrow` or `text`
            x_axis: the x-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used; a string is looked up as an embedding in this set.
            y_axis: the y-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used; a string is looked up as an embedding in this set.
            x_label: an optional label used for x-axis; if not given, it is set based on value of `x_axis`.
            y_label: an optional label used for y-axis; if not given, it is set based on value of `y_axis`.
            title: an optional title for the plot.
            color: the color of the dots
            show_ops: setting to also show the applied operations, only works for `text`
            annot: should the points be annotated
            axis_option: a string which is passed as `option` argument to `matplotlib.pyplot.axis` in order to control
                axis properties (e.g. using `'equal'` make circles shown circular in the plot). This might be useful
                for preserving geometric relationships (e.g. orthogonality) in the generated plot. See `matplotlib.pyplot.axis`
                [documentation](https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.axis.html#matplotlib-pyplot-axis)
                for possible values and their description.

        Returns:
            this embedding set, so calls can be chained.
        """
        x_axis = self[x_axis] if isinstance(x_axis, str) else x_axis
        y_axis = self[y_axis] if isinstance(y_axis, str) else y_axis

        projected = []
        for member in self.embeddings.values():
            x_val, x_lab = member._get_plot_axis_value_and_label(x_axis, dir="x")
            y_val, y_lab = member._get_plot_axis_value_and_label(y_axis, dir="y")
            projected.append(
                Embedding(name=member.name, vector=[x_val, y_val], orig=member.orig))

        # Default axis labels come from the axis helper of the last embedding visited.
        handle_2d_plot(
            projected,
            kind=kind,
            color=color,
            xlabel=x_label if x_label is not None else x_lab,
            ylabel=y_label if y_label is not None else y_lab,
            title=title,
            show_operations=show_ops,
            annot=annot,
            axis_option=axis_option,
        )
        return self
コード例 #8
0
 def _get_embedding(self, query: str):
     """
     Embed `query` by summing the model's token features, excluding special tokens.

     Special tokens are identified via the tokenizer's `special_tokens_mask`.
     NOTE(review): assumes the feature array is (tokens, dim) and aligned with the
     mask produced by `self.model.tokenizer` — confirm for each model.
     """
     token_features = np.array(self.model(query, padding=False)[0])
     encoded = self.model.tokenizer(
         query, return_special_tokens_mask=True, return_tensors="np")
     is_special = encoded["special_tokens_mask"][0]
     keep_rows = np.logical_not(is_special)
     summed = token_features[keep_rows].sum(axis=0)
     return Embedding(query, summed)
コード例 #9
0
ファイル: _convert_lang.py プロジェクト: cirrushuet/whatlies
    def __getitem__(
            self, query: Union[str,
                               List[str]]) -> Union[Embedding, EmbeddingSet]:
        """
        Retrieve a single embedding or a set of embeddings from the ConveRT model.

        With the "encode_sequence" signature the `sequence_encoding` output is summed
        over axis 1 (presumably the token axis — confirm). Otherwise the `default`
        output is used.

        Arguments:
            query: single string or list of strings

        **Usage**

        ```python
        > from whatlies.language import ConveRTLanguage
        > lang = ConveRTLanguage()
        > lang['bank']
        > lang = ConveRTLanguage()
        > lang[['bank of the river', 'money on the bank', 'bank']]
        ```
        """
        if not isinstance(query, str):
            return EmbeddingSet(*[self[tok] for tok in query])

        outputs = self.model(tf.convert_to_tensor([query]))
        if self.signature != "encode_sequence":
            vec = outputs["default"].numpy()[0]
        else:
            vec = outputs["sequence_encoding"].numpy().sum(axis=1)[0]
        return Embedding(query, vec)
コード例 #10
0
ファイル: language.py プロジェクト: Manikant92/whatlies
 def __getitem__(self, string):
     """
     Retrieve a single embedding (for a string) or a set of embeddings (otherwise).

     A string is validated with `self._input_str_legal`. The span returned by
     `_selected_idx_spacy` is then embedded after the square brackets are removed.
     """
     if not isinstance(string, str):
         return EmbeddingSet(*[self[tok] for tok in string])
     self._input_str_legal(string)
     start, end = _selected_idx_spacy(string)
     doc = self.nlp(string.replace("[", "").replace("]", ""))
     return Embedding(string, doc[start:end].vector)
コード例 #11
0
 def __getitem__(self, string):
     """
     Retrieve the embedding of `string`, optionally restricted to a bracketed span.

     Words wrapped in square brackets (e.g. "the [bank] of the river") select the
     span whose vector is returned. Without a complete bracket pair, the vector of the
     whole processed string is returned. The returned `Embedding` is named after the
     original `string`, brackets included.

     NOTE(review): span indices are computed from words split on single spaces, which
     assumes the pipeline tokenizes the cleaned string on whitespace; punctuation that
     the tokenizer splits off will shift the span.
     """
     # Strip the brackets before running the pipeline. Otherwise "[" and "]" become
     # tokens of their own, shifting the token indices and polluting the vector.
     clean_string = string.replace("[", "").replace("]", "")
     doc = self.nlp(clean_string)

     start, end = None, None
     for idx, word in enumerate(string.split(" ")):
         if not word:
             # Consecutive spaces produce empty words; `word[0]` would raise.
             continue
         if word[0] == "[":
             start = idx
         if word[-1] == "]":
             end = idx + 1

     # `None` sentinels so that a bracket on the very first word (start == 0) is
     # honoured rather than mistaken for "no bracket given".
     if start is not None and end is not None:
         return Embedding(string, doc[start:end].vector)
     return Embedding(string, doc.vector)
コード例 #12
0
    def __getitem__(self, query):
        """
        Retrieve a single embedding or a set of embeddings.

        Each string is looked up directly as a key of `self.s2v`.

        Arguments:
            query: single string or list of strings, e.g. 'duck|NOUN'

        **Usage**
        ```python
        > lang['duck|NOUN']
        > lang[['duck|NOUN', 'duck|VERB']]
        ```
        """
        if not isinstance(query, str):
            return EmbeddingSet(*[self[key] for key in query])
        return Embedding(query, self.s2v[query])
コード例 #13
0
ファイル: _floret_lang.py プロジェクト: RasaHQ/whatlies
    def __getitem__(self, query: Union[str, List[str]]):
        """
        Retrieve a single embedding or a set of embeddings.

        Each string is embedded with `self.model.get_word_vector`.

        Arguments:
            query: single string or list of strings

        **Usage**
        ```python
        > lang['python']
        > lang[['python', 'snake']]
        > lang[['nobody expects', 'the spanish inquisition']]
        ```
        """
        if not isinstance(query, str):
            return EmbeddingSet(*[self[tok] for tok in query])
        word_vector = self.model.get_word_vector(query)
        return Embedding(query, word_vector)
コード例 #14
0
ファイル: fasttext_lang.py プロジェクト: dmccreary/whatlies
    def __getitem__(self, query: Union[str, List[str]]):
        """
        Retrieve a single embedding or a set of embeddings from the fastText model.

        Each string is first validated with `self._input_str_legal` and then embedded
        with `self.ft.get_word_vector`.

        Arguments:
            query: single string or list of strings

        **Usage**
        ```python
        > lang = FasttextLanguage("cc.en.300.bin")
        > lang['python']
        > lang[['python', 'snake']]
        > lang[['nobody expects', 'the spanish inquisition']]
        ```
        """
        if not isinstance(query, str):
            return EmbeddingSet(*[self[tok] for tok in query])
        self._input_str_legal(query)
        word_vector = self.ft.get_word_vector(query)
        return Embedding(query, word_vector)
コード例 #15
0
    def average(self, name=None):
        """
        Takes the element-wise mean of all embedding vectors in the set and returns it
        as a new `Embedding`.

        Arguments:
            name: manually specify the name of the average embedding; when falsy the
                name defaults to `"<set name>.average()"`

        **Usage**

        ```python
        from whatlies.embedding import Embedding
        from whatlies.embeddingset import EmbeddingSet

        foo = Embedding("foo", [1.0, 0.0])
        bar = Embedding("bar", [0.0, 1.0])
        emb = EmbeddingSet(foo, bar)

        emb.average().vector                   # [0.5, 0.5]
        emb.average(name="the-average").vector # [0.5, 0.5]
        ```
        """
        label = name if name else f"{self.name}.average()"
        stacked = np.array([member.vector for member in self.embeddings.values()])
        return Embedding(label, np.mean(stacked, axis=0))
コード例 #16
0
    def __getitem__(self, query: Union[str, List[str]]):
        """
        Retrieve a single embedding or a set of embeddings. Depending on the spaCy model
        the strings can support multiple tokens of text but they can also use the Bert DSL.
        See the Language Options documentation: https://rasahq.github.io/whatlies/tutorial/languages/#bert-style.

        A string is validated with `self._input_str_legal`. The token span chosen by
        `_selected_idx_spacy` is embedded after the square brackets are removed.

        Arguments:
            query: single string or list of strings

        **Usage**
        ```python
        > lang = SpacyLanguage("en_core_web_md")
        > lang['python']
        > lang[['python', 'snake']]
        > lang[['nobody expects', 'the spanish inquisition']]
        ```
        """
        if not isinstance(query, str):
            return EmbeddingSet(*[self[tok] for tok in query])
        self._input_str_legal(query)
        start, end = _selected_idx_spacy(query)
        bracketless = query.replace("[", "").replace("]", "")
        selected_span = self.nlp(bracketless)[start:end]
        return Embedding(query, selected_span.vector)
コード例 #17
0
ファイル: language.py プロジェクト: Manikant92/whatlies
 def __getitem__(self, string):
     """Look up a string in `self.s2v`, or every element of a non-string iterable."""
     if not isinstance(string, str):
         return EmbeddingSet(*[self[tok] for tok in string])
     return Embedding(string, self.s2v[string])
コード例 #18
0
ファイル: embeddingset.py プロジェクト: timvink/whatlies
    def plot_interactive_matrix(
        self,
        *axes: Union[int, str, Embedding],
        axes_metric: Optional[Union[str, Callable, Sequence]] = None,
        annot: bool = True,
        width: int = 200,
        height: int = 200,
    ):
        """
        Makes highly interactive plot of the set of embeddings.

        Arguments:
            axes: the axes that we wish to plot; each could be either an integer, the name of
                an existing embedding, or an `Embedding` instance (default: `0, 1`).
            axes_metric: the metric used to project each embedding on the axes; only used when the corresponding
                axis is a string or an `Embedding` instance. It could be a string (`'cosine_similarity'`,
                `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input and
                returns a scalar value as output. To set different metrics for different axes, a list or a tuple of
                the same length as `axes` could be given. By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
            annot: drawn points should be annotated
            width: width of the visual
            height: height of the visual

        Returns:
            an Altair chart repeating a scatter plot over every (row, column) pair of axes.

        Raises:
            ValueError: when `axes_metric` is a list/tuple whose length differs from the number of axes.

        Note:
            axes that resolve to the same label (e.g. the same dimension twice) share a single column.

        **Usage**

        ```python
        from whatlies.language import SpacyLanguage
        from whatlies.transformers import Pca

        words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
                 "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
                 "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
                 "person", "family", "brother", "sister"]

        lang = SpacyLanguage("en_core_web_sm")
        emb = lang[words]

        emb.transform(Pca(3)).plot_interactive_matrix(0, 1, 2)
        ```
        """
        if not axes:
            axes = [0, 1]

        per_axis = isinstance(axes_metric, (list, tuple))
        if per_axis and len(axes_metric) != len(axes):
            raise ValueError(
                f"The number of given axes metrics should be the same as the number of given axes. Got {len(axes)} axes vs. {len(axes_metric)} metrics."
            )
        metrics = axes_metric if per_axis else [axes_metric] * len(axes)

        # One column of projected values per axis, keyed by the axis label.
        X = self.to_X()
        columns = {}
        for axis, metric in zip(axes, metrics):
            if isinstance(axis, int):
                columns["Dimension " + str(axis)] = X[:, axis]
                continue
            if isinstance(axis, str):
                axis = self[axis]
            mapping = Embedding._get_plot_axis_metric_callable(metric)
            columns[axis.name] = self.compare_against(axis, mapping=mapping)

        members = list(self.embeddings.values())
        plot_df = pd.DataFrame(columns)
        plot_df["name"] = [member.name for member in members]
        plot_df["original"] = [member.orig for member in members]
        axes_names = list(columns)

        chart = alt.Chart(plot_df).mark_circle().encode(
            x=alt.X(alt.repeat("column"), type="quantitative"),
            y=alt.Y(alt.repeat("row"), type="quantitative"),
            tooltip=["name", "original"],
        )
        if annot:
            labels = chart.mark_text(dx=-15, dy=3, color="black").encode(
                text="original")
            chart = chart + labels

        return (chart.properties(width=width, height=height)
                .repeat(row=axes_names[::-1], column=axes_names)
                .interactive())
コード例 #19
0
ファイル: _spacy_lang.py プロジェクト: RasaHQ/whatlies
 def _get_embedding(self, query: str) -> Embedding:
     """Process `query` with `self.model` and wrap the resulting document vector."""
     doc = self.model(query)
     return Embedding(query, doc.vector)
コード例 #20
0
ファイル: embeddingset.py プロジェクト: timvink/whatlies
    def plot_3d(
        self,
        x_axis: Union[int, str, Embedding] = 0,
        y_axis: Union[int, str, Embedding] = 1,
        z_axis: Union[int, str, Embedding] = 2,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        z_label: Optional[str] = None,
        title: Optional[str] = None,
        color: str = None,
        axis_metric: Optional[Union[str, Callable, Sequence]] = None,
        annot: bool = True,
    ):
        """
        Creates a 3d visualisation of the embedding.

        Arguments:
            x_axis: the x-axis to be used, must be given when dim > 3; if an integer, the corresponding
                dimension of embedding is used.
            y_axis: the y-axis to be used, must be given when dim > 3; if an integer, the corresponding
                dimension of embedding is used.
            z_axis: the z-axis to be used, must be given when dim > 3; if an integer, the corresponding
                dimension of embedding is used.
            x_label: an optional label used for x-axis; if not given, it is set based on value of `x_axis`.
            y_label: an optional label used for y-axis; if not given, it is set based on value of `y_axis`.
            z_label: an optional label used for z-axis; if not given, it is set based on value of `z_axis`.
            title: an optional title for the plot.
            color: the property to use for the color; non-float values are mapped to
                integers in order of first appearance, float values are used as-is.
            axis_metric: the metric used to project each embedding on the axes; only used when the corresponding
                axis is a string or an `Embedding` instance. It could be a string (`'cosine_similarity'`,
                `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input and
                returns a scalar value as output. To set different metrics of the three different axes,
                you can pass a list/tuple of size three that describes the metrics you're interested in.
                By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
            annot: drawn points should be annotated

        Returns:
            the matplotlib 3D axes the points were drawn on.

        **Usage**

        ```python
        from whatlies.language import SpacyLanguage
        from whatlies.transformers import Pca

        words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
                 "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
                 "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
                 "person", "family", "brother", "sister"]

        lang = SpacyLanguage("en_core_web_sm")
        emb = lang[words]

        emb.transform(Pca(3)).plot_3d(annot=True)
        emb.transform(Pca(3)).plot_3d("king", "dog", "red")
        emb.transform(Pca(3)).plot_3d("king", "dog", "red", axis_metric="cosine_distance")
        ```
        """
        if isinstance(x_axis, str):
            x_axis = self[x_axis]
        if isinstance(y_axis, str):
            y_axis = self[y_axis]
        if isinstance(z_axis, str):
            z_axis = self[z_axis]

        if isinstance(axis_metric, (list, tuple)):
            x_axis_metric = axis_metric[0]
            y_axis_metric = axis_metric[1]
            z_axis_metric = axis_metric[2]
        else:
            x_axis_metric = axis_metric
            y_axis_metric = axis_metric
            z_axis_metric = axis_metric

        def _axis_values_and_label(axis, metric, label):
            # Return per-embedding values for one axis plus its (possibly overridden) label.
            if isinstance(axis, int):
                values = self.to_X()[:, axis]
                default_label = "Dimension " + str(axis)
            else:
                mapping = Embedding._get_plot_axis_metric_callable(metric)
                values = self.compare_against(axis, mapping=mapping)
                default_label = axis.name
            return values, (label if label is not None else default_label)

        x_val, x_lab = _axis_values_and_label(x_axis, x_axis_metric, x_label)
        y_val, y_lab = _axis_values_and_label(y_axis, y_axis_metric, y_label)
        z_val, z_lab = _axis_values_and_label(z_axis, z_axis_metric, z_label)

        # Save relevant information in a dataframe for plotting later.
        plot_df = pd.DataFrame({
            "x_axis": x_val,
            "y_axis": y_val,
            "z_axis": z_val,
            "name": [v.name for v in self.embeddings.values()],
            "original": [v.orig for v in self.embeddings.values()],
        })

        # Deal with the colors of the dots.
        if color:
            plot_df["color"] = [
                getattr(v, color) if hasattr(v, color) else ""
                for v in self.embeddings.values()
            ]
            # `dict.fromkeys` keeps first-seen order. Iterating a `set` of strings depends
            # on hash randomization, which made the colour assignment change between runs.
            color_map = {
                k: v for v, k in enumerate(dict.fromkeys(plot_df["color"]))
            }
            color_val = [
                color_map[k] if not isinstance(k, float) else k
                for k in plot_df["color"]
            ]
        else:
            color_val = None

        ax = plt.axes(projection="3d")
        ax.scatter3D(plot_df["x_axis"],
                     plot_df["y_axis"],
                     plot_df["z_axis"],
                     c=color_val,
                     s=25)

        # Set the labels, titles, text annotations.
        ax.set_xlabel(x_lab)
        ax.set_ylabel(y_lab)
        ax.set_zlabel(z_lab)

        if annot:
            for _, row in plot_df.iterrows():
                # Lift the text slightly above the point so it does not overlap the marker.
                ax.text(row["x_axis"], row["y_axis"], row["z_axis"] + 0.05,
                        row["original"])
        if title:
            ax.set_title(label=title)
        return ax
コード例 #21
0
ファイル: embeddingset.py プロジェクト: timvink/whatlies
    def plot(
        self,
        kind: str = "arrow",
        x_axis: Union[int, str, Embedding] = 0,
        y_axis: Union[int, str, Embedding] = 1,
        axis_metric: Optional[Union[str, Callable, Sequence]] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        title: Optional[str] = None,
        color: str = None,
        show_ops: bool = False,
        annot: bool = True,
        axis_option: Optional[str] = None,
    ):
        """
        Makes (perhaps inferior) matplotlib plot. Consider using `plot_interactive` instead.

        Every embedding is projected onto the two requested axes (with the requested
        metrics) and the resulting 2D points are handed to `handle_2d_plot`.

        Arguments:
            kind: what kind of plot to make, can be `scatter`, `arrow` or `text`
            x_axis: the x-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used.
            y_axis: the y-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used.
            axis_metric: the metric used to project each embedding on the axes; only used when the corresponding
                axis (i.e. `x_axis` or `y_axis`) is a string or an `Embedding` instance. It could be a string
                (`'cosine_similarity'`, `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input
                and returns a scalar value as output. To set different metrics for x- and y-axis, a list or a tuple of
                two elements could be given. By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
            x_label: an optional label used for x-axis; if not given, it is set based on value of `x_axis`.
            y_label: an optional label used for y-axis; if not given, it is set based on value of `y_axis`.
            title: an optional title for the plot.
            color: the color of the dots
            show_ops: setting to also show the applied operations, only works for `text`
            annot: should the points be annotated
            axis_option: a string which is passed as `option` argument to `matplotlib.pyplot.axis` in order to control
                axis properties (e.g. using `'equal'` make circles shown circular in the plot). This might be useful
                for preserving geometric relationships (e.g. orthogonality) in the generated plot. See `matplotlib.pyplot.axis`
                [documentation](https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.axis.html#matplotlib-pyplot-axis)
                for possible values and their description.

        Returns:
            this embedding set, so calls can be chained.
        """
        x_axis = self[x_axis] if isinstance(x_axis, str) else x_axis
        y_axis = self[y_axis] if isinstance(y_axis, str) else y_axis

        if isinstance(axis_metric, (list, tuple)):
            x_axis_metric, y_axis_metric = axis_metric[0], axis_metric[1]
        else:
            x_axis_metric = y_axis_metric = axis_metric

        projected = []
        for member in self.embeddings.values():
            x_val, x_lab = member._get_plot_axis_value_and_label(
                x_axis, x_axis_metric, dir="x")
            y_val, y_lab = member._get_plot_axis_value_and_label(
                y_axis, y_axis_metric, dir="y")
            projected.append(
                Embedding(name=member.name, vector=[x_val, y_val], orig=member.orig))

        # Default axis labels come from the axis helper of the last embedding visited.
        handle_2d_plot(
            projected,
            kind=kind,
            color=color,
            xlabel=x_label if x_label is not None else x_lab,
            ylabel=y_label if y_label is not None else y_lab,
            title=title,
            show_operations=show_ops,
            annot=annot,
            axis_option=axis_option,
        )
        return self
コード例 #22
0
ファイル: _tfhub_lang.py プロジェクト: cirrushuet/whatlies
 def _get_embedding(self, query: str) -> Embedding:
     """Embed `query` by running it through `self.model` as a batch of one."""
     batch_output = self.model([query])
     return Embedding(query, batch_output.numpy()[0])
コード例 #23
0
ファイル: embeddingset.py プロジェクト: timvink/whatlies
    def plot_interactive(
        self,
        x_axis: Union[int, str, Embedding] = 0,
        y_axis: Union[int, str, Embedding] = 1,
        axis_metric: Optional[Union[str, Callable, Sequence]] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        title: Optional[str] = None,
        annot: bool = True,
        color: Union[None, str] = None,
    ):
        """
        Makes highly interactive plot of the set of embeddings.

        Arguments:
            x_axis: the x-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used.
            y_axis: the y-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used.
            axis_metric: the metric used to project each embedding on the axes; only used when the corresponding
                axis (i.e. `x_axis` or `y_axis`) is a string or an `Embedding` instance. It could be a string
                (`'cosine_similarity'`, `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input
                and returns a scalar value as output. To set different metrics for x- and y-axis, a list or a tuple of
                two elements could be given. By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
            x_label: an optional label used for x-axis; if not given, it is set based on `x_axis` value.
            y_label: an optional label used for y-axis; if not given, it is set based on `y_axis` value.
            title: an optional title for the plot; if not given, it is set based on `x_axis` and `y_axis` values.
            annot: drawn points should be annotated
            color: a property that will be used for plotting

        Returns:
            an interactive Altair chart (scatter, plus text annotations when `annot`).

        **Usage**

        ```python
        from whatlies.language import SpacyLanguage

        words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
                 "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
                 "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
                 "person", "family", "brother", "sister"]

        lang = SpacyLanguage("en_core_web_sm")
        emb = lang[words]

        emb.plot_interactive('man', 'woman')
        ```
        """
        if isinstance(x_axis, str):
            x_axis = self[x_axis]
        if isinstance(y_axis, str):
            y_axis = self[y_axis]

        if isinstance(axis_metric, (list, tuple)):
            x_axis_metric = axis_metric[0]
            y_axis_metric = axis_metric[1]
        else:
            x_axis_metric = axis_metric
            y_axis_metric = axis_metric

        # Determine axes values and labels
        if isinstance(x_axis, int):
            x_val = self.to_X()[:, x_axis]
            x_lab = "Dimension " + str(x_axis)
        else:
            x_axis_metric = Embedding._get_plot_axis_metric_callable(
                x_axis_metric)
            x_val = self.compare_against(x_axis, mapping=x_axis_metric)
            x_lab = x_axis.name

        if isinstance(y_axis, int):
            y_val = self.to_X()[:, y_axis]
            y_lab = "Dimension " + str(y_axis)
        else:
            y_axis_metric = Embedding._get_plot_axis_metric_callable(
                y_axis_metric)
            y_val = self.compare_against(y_axis, mapping=y_axis_metric)
            y_lab = y_axis.name
        x_label = x_label if x_label is not None else x_lab
        y_label = y_label if y_label is not None else y_lab
        # The default title uses the computed axis names, not any overridden labels.
        title = title if title is not None else f"{x_lab} vs. {y_lab}"

        plot_df = pd.DataFrame({
            "x_axis": x_val,
            "y_axis": y_val,
            "name": [v.name for v in self.embeddings.values()],
            "original": [v.orig for v in self.embeddings.values()],
        })

        if color:
            plot_df[color] = [
                getattr(v, color) if hasattr(v, color) else ""
                for v in self.embeddings.values()
            ]

        result = (alt.Chart(plot_df).mark_circle(size=60).encode(
            x=alt.X("x_axis", axis=alt.Axis(title=x_label)),
            # The y encoding needs the Y channel class (previously `alt.X` by mistake).
            y=alt.Y("y_axis", axis=alt.Axis(title=y_label)),
            tooltip=["name", "original"],
            color=alt.Color(":N", legend=None)
            if not color else alt.Color(color),
        ).properties(title=title).interactive())

        if annot:
            text = (alt.Chart(plot_df).mark_text(dx=-15, dy=3,
                                                 color="black").encode(
                                                     x="x_axis",
                                                     y="y_axis",
                                                     text="original",
                                                 ))
            result = result + text
        return result