Example #1
0
    def plot_interactive_matrix(
        self,
        *axes: Union[int, str, Embedding],
        axes_metric: Optional[Union[str, Callable, Sequence]] = None,
        annot: bool = True,
        width: int = 200,
        height: int = 200,
    ):
        """
        Draws an interactive scatter-plot matrix with one panel per pair of axes.

        Arguments:
            axes: the axes to plot. An integer selects that column of `self.to_X()`; a string is
                first looked up in this set (`self[axis]`); an `Embedding` instance is used as
                given. When no axes are passed, `0, 1` is used.
            axes_metric: the metric used to project each embedding on the axes; only used when the corresponding
                axis is a string or an `Embedding` instance. It could be a string (`'cosine_similarity'`,
                `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input and
                returns a scalar value as output. A list or a tuple with one entry per axis sets a
                different metric for each axis; any other value is applied to every axis.
                By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
            annot: when true, a text layer showing each point's `orig` value is drawn on top of the points
            width: width of each panel of the matrix
            height: height of each panel of the matrix

        Returns:
            The repeated Altair chart, made interactive. Columns follow the order of `axes`,
            rows follow the reverse order.

        Raises:
            ValueError: if `axes_metric` is a list/tuple whose length differs from the number of axes.

        **Usage**

        ```python
        from whatlies.language import SpacyLanguage
        from whatlies.transformers import Pca

        words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
                 "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
                 "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
                 "person", "family", "brother", "sister"]

        lang = SpacyLanguage("en_core_web_sm")
        emb = lang[words]

        emb.transform(Pca(3)).plot_interactive_matrix(0, 1, 2)
        ```
        """
        # No positional axes were passed: fall back to the first two dimensions.
        if not axes:
            axes = [0, 1]

        # Normalise `axes_metric` so that there is exactly one metric per axis.
        if isinstance(axes_metric, (list, tuple)):
            if len(axes_metric) != len(axes):
                raise ValueError(
                    f"The number of given axes metrics should be the same as the number of given axes. Got {len(axes)} axes vs. {len(axes_metric)} metrics."
                )
            metrics = axes_metric
        else:
            metrics = [axes_metric] * len(axes)

        # Build one data column per axis, keyed by the label shown in the plot.
        columns = {}
        matrix = self.to_X()
        for axis, metric in zip(axes, metrics):
            if isinstance(axis, int):
                # Integer axis: take the raw dimension.
                columns["Dimension " + str(axis)] = matrix[:, axis]
                continue
            # Named axis: resolve a string to its embedding, then project onto it.
            reference = self[axis] if isinstance(axis, str) else axis
            projector = Embedding._get_plot_axis_metric_callable(metric)
            columns[reference.name] = self.compare_against(reference, mapping=projector)

        axes_names = list(columns)
        members = list(self.embeddings.values())

        plot_df = pd.DataFrame(columns)
        plot_df["name"] = [member.name for member in members]
        plot_df["original"] = [member.orig for member in members]

        # Base layer: one circle per embedding; the fields are filled in by `repeat` below.
        x_channel = alt.X(alt.repeat("column"), type="quantitative")
        y_channel = alt.Y(alt.repeat("row"), type="quantitative")
        chart = alt.Chart(plot_df).mark_circle().encode(
            x=x_channel,
            y=y_channel,
            tooltip=["name", "original"],
        )

        if annot:
            # Layer the labels on top of the points, reusing the positional encodings.
            labels = chart.mark_text(dx=-15, dy=3, color="black").encode(text="original")
            chart = chart + labels

        sized = chart.properties(width=width, height=height)
        grid = sized.repeat(row=list(reversed(axes_names)), column=axes_names)
        return grid.interactive()
Example #2
0
    def plot_3d(
        self,
        x_axis: Union[int, str, Embedding] = 0,
        y_axis: Union[int, str, Embedding] = 1,
        z_axis: Union[int, str, Embedding] = 2,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        z_label: Optional[str] = None,
        title: Optional[str] = None,
        color: Optional[str] = None,
        axis_metric: Optional[Union[str, Callable, Sequence]] = None,
        annot: bool = True,
    ):
        """
        Creates a 3d visualisation of the embedding.

        Arguments:
            x_axis: the x-axis to be used, must be given when dim > 3; if an integer, the corresponding
                dimension of embedding is used.
            y_axis: the y-axis to be used, must be given when dim > 3; if an integer, the corresponding
                dimension of embedding is used.
            z_axis: the z-axis to be used, must be given when dim > 3; if an integer, the corresponding
                dimension of embedding is used.
            x_label: an optional label used for x-axis; if not given, it is set based on value of `x_axis`.
            y_label: an optional label used for y-axis; if not given, it is set based on value of `y_axis`.
            z_label: an optional label used for z-axis; if not given, it is set based on value of `z_axis`.
            title: an optional title for the plot.
            color: the name of the embedding attribute to use for the color; embeddings that lack
                the attribute get an empty string as their color category.
            axis_metric: the metric used to project each embedding on the axes; only used when the corresponding
                axis is a string or an `Embedding` instance. It could be a string (`'cosine_similarity'`,
                `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input and
                returns a scalar value as output. To set different metrics of the three different axes,
                you can pass a list/tuple of size three that describes the metrics you're interested in.
                By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
            annot: drawn points should be annotated

        Returns:
            The 3d axes object that the points were drawn on.

        Raises:
            ValueError: if `axis_metric` is a list/tuple that does not hold exactly three metrics.

        **Usage**

        ```python
        from whatlies.language import SpacyLanguage
        from whatlies.transformers import Pca

        words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
                 "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
                 "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
                 "person", "family", "brother", "sister"]

        lang = SpacyLanguage("en_core_web_sm")
        emb = lang[words]

        emb.transform(Pca(3)).plot_3d(annot=True)
        emb.transform(Pca(3)).plot_3d("king", "dog", "red")
        emb.transform(Pca(3)).plot_3d("king", "dog", "red", axis_metric="cosine_distance")
        ```
        """
        # Normalise `axis_metric` into one metric per axis. A sequence of the wrong
        # length is rejected up front instead of failing with a bare IndexError
        # (too short) or silently dropping entries (too long).
        if isinstance(axis_metric, (list, tuple)):
            if len(axis_metric) != 3:
                raise ValueError(
                    "The number of given axes metrics should be the same as the number of "
                    f"given axes. Got 3 axes vs. {len(axis_metric)} metrics."
                )
            metrics = list(axis_metric)
        else:
            metrics = [axis_metric] * 3

        # Determine the values and the label of each axis.
        matrix = None  # `self.to_X()`, fetched at most once and only if an integer axis needs it
        values, labels = [], []
        for axis, metric, custom_label in zip(
            (x_axis, y_axis, z_axis), metrics, (x_label, y_label, z_label)
        ):
            if isinstance(axis, str):
                axis = self[axis]
            if isinstance(axis, int):
                if matrix is None:
                    matrix = self.to_X()
                values.append(matrix[:, axis])
                derived_label = "Dimension " + str(axis)
            else:
                projector = Embedding._get_plot_axis_metric_callable(metric)
                values.append(self.compare_against(axis, mapping=projector))
                derived_label = axis.name
            # An explicitly given label always wins over the derived one.
            labels.append(custom_label if custom_label is not None else derived_label)
        x_val, y_val, z_val = values
        x_lab, y_lab, z_lab = labels

        # Save relevant information in a dataframe for plotting later.
        members = list(self.embeddings.values())
        plot_df = pd.DataFrame({
            "x_axis": x_val,
            "y_axis": y_val,
            "z_axis": z_val,
            "name": [member.name for member in members],
            "original": [member.orig for member in members],
        })

        # Deal with the colors of the dots.
        color_val = None
        if color:
            plot_df["color"] = [getattr(member, color, "") for member in members]
            # Non-float values are treated as categories and mapped to integer codes;
            # float values are passed through unchanged.
            # NOTE: the codes follow set iteration order, which is arbitrary.
            color_map = {
                category: code
                for code, category in enumerate(set(plot_df["color"]))
            }
            color_val = [
                value if isinstance(value, float) else color_map[value]
                for value in plot_df["color"]
            ]

        ax = plt.axes(projection="3d")
        ax.scatter3D(plot_df["x_axis"],
                     plot_df["y_axis"],
                     plot_df["z_axis"],
                     c=color_val,
                     s=25)

        # Set the labels, titles, text annotations.
        ax.set_xlabel(x_lab)
        ax.set_ylabel(y_lab)
        ax.set_zlabel(z_lab)

        if annot:
            # The text is placed slightly above each point (z + 0.05).
            for x, y, z, text in zip(plot_df["x_axis"], plot_df["y_axis"],
                                     plot_df["z_axis"], plot_df["original"]):
                ax.text(x, y, z + 0.05, text)
        if title:
            ax.set_title(label=title)
        return ax
Example #3
0
    def plot_interactive(
        self,
        x_axis: Union[int, str, Embedding] = 0,
        y_axis: Union[int, str, Embedding] = 1,
        axis_metric: Optional[Union[str, Callable, Sequence]] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        title: Optional[str] = None,
        annot: bool = True,
        color: Union[None, str] = None,
    ):
        """
        Makes highly interactive plot of the set of embeddings.

        Arguments:
            x_axis: the x-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used.
            y_axis: the y-axis to be used, must be given when dim > 2; if an integer, the corresponding
                dimension of embedding is used.
            axis_metric: the metric used to project each embedding on the axes; only used when the corresponding
                axis (i.e. `x_axis` or `y_axis`) is a string or an `Embedding` instance. It could be a string
                (`'cosine_similarity'`, `'cosine_distance'` or `'euclidean'`), or a callable that takes two vectors as input
                and returns a scalar value as output. To set different metrics for x- and y-axis, a list or a tuple of
                two elements could be given. By default (`None`), normalized scalar projection (i.e. `>` operator) is used.
            x_label: an optional label used for x-axis; if not given, it is set based on `x_axis` value.
            y_label: an optional label used for y-axis; if not given, it is set based on `y_axis` value.
            title: an optional title for the plot; if not given, it is `"<x> vs. <y>"`, built from the
                labels derived from `x_axis` and `y_axis` (not from `x_label`/`y_label`).
            annot: drawn points should be annotated
            color: the name of the embedding attribute used to color the points; embeddings that
                lack the attribute get an empty string as their value.

        Returns:
            The interactive Altair chart, layered with a text chart when `annot` is true.

        Raises:
            ValueError: if `axis_metric` is a list/tuple that does not hold exactly two metrics.

        **Usage**

        ```python
        from whatlies.language import SpacyLanguage

        words = ["prince", "princess", "nurse", "doctor", "banker", "man", "woman",
                 "cousin", "neice", "king", "queen", "dude", "guy", "gal", "fire",
                 "dog", "cat", "mouse", "red", "bluee", "green", "yellow", "water",
                 "person", "family", "brother", "sister"]

        lang = SpacyLanguage("en_core_web_sm")
        emb = lang[words]

        emb.plot_interactive('man', 'woman')
        ```
        """
        # Normalise `axis_metric` into one metric per axis. A sequence of the wrong
        # length is rejected up front instead of failing with a bare IndexError
        # (too short) or silently dropping entries (too long).
        if isinstance(axis_metric, (list, tuple)):
            if len(axis_metric) != 2:
                raise ValueError(
                    "The number of given axes metrics should be the same as the number of "
                    f"given axes. Got 2 axes vs. {len(axis_metric)} metrics."
                )
            metrics = list(axis_metric)
        else:
            metrics = [axis_metric] * 2

        # Determine axes values and labels
        matrix = None  # `self.to_X()`, fetched at most once and only if an integer axis needs it
        values, derived_labels = [], []
        for axis, metric in zip((x_axis, y_axis), metrics):
            if isinstance(axis, str):
                axis = self[axis]
            if isinstance(axis, int):
                if matrix is None:
                    matrix = self.to_X()
                values.append(matrix[:, axis])
                derived_labels.append("Dimension " + str(axis))
            else:
                projector = Embedding._get_plot_axis_metric_callable(metric)
                values.append(self.compare_against(axis, mapping=projector))
                derived_labels.append(axis.name)
        x_val, y_val = values
        x_lab, y_lab = derived_labels

        # Explicit labels win; the default title always uses the derived labels.
        x_label = x_label if x_label is not None else x_lab
        y_label = y_label if y_label is not None else y_lab
        title = title if title is not None else f"{x_lab} vs. {y_lab}"

        members = list(self.embeddings.values())
        plot_df = pd.DataFrame({
            "x_axis": x_val,
            "y_axis": y_val,
            "name": [member.name for member in members],
            "original": [member.orig for member in members],
        })

        if color:
            # The column is named after the attribute so that it can be used
            # directly as the field of the color channel below.
            plot_df[color] = [getattr(member, color, "") for member in members]
            color_channel = alt.Color(color)
        else:
            # No color attribute requested: no color field, and no legend.
            color_channel = alt.Color(":N", legend=None)

        result = (alt.Chart(plot_df).mark_circle(size=60).encode(
            x=alt.X("x_axis", axis=alt.Axis(title=x_label)),
            # The y channel is built with `alt.Y` (not `alt.X`).
            y=alt.Y("y_axis", axis=alt.Axis(title=y_label)),
            tooltip=["name", "original"],
            color=color_channel,
        ).properties(title=title).interactive())

        if annot:
            # Separate text layer carrying each point's original text.
            text = alt.Chart(plot_df).mark_text(dx=-15, dy=3, color="black").encode(
                x="x_axis",
                y="y_axis",
                text="original",
            )
            result = result + text
        return result