Example #1
    def __post_init__(self):
        if len(self.y_true) != self.y_pred_proba.shape[0]:
            raise ValueError(
                "y_true and y_pred_proba must have the same number of observations"
            )

        self.multilabel = is_multilabel(self.y_true)
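
Every example here relies on an `is_multilabel` helper that is not shown. Judging only from the call sites (a `y` that is either a list of label strings or a list of lists of label strings), a minimal sketch could look like the following; the real implementation may differ.

from typing import List, Union


def is_multilabel(y: Union[List[str], List[List[str]]]) -> bool:
    # Assumption: y is multilabel when each element is itself a collection
    # of labels (e.g. a list of strings) rather than a single label string.
    return len(y) > 0 and not isinstance(y[0], str)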
Example #2
    def __post_init__(self):
        self.evaluation = ClassificationEvaluation(
            labels=self.labels,
            X=self.X,
            y_true=self.y_true,
            y_pred_proba=self.y_pred_proba,
            metric_funcs=self.metric_funcs,
        )
        self.multilabel = is_multilabel(self.y_true)
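
The fields referenced here (labels, X, y_true, y_pred_proba, metric_funcs) imply a dataclass shaped roughly like the one below. The class name and field types are assumptions for illustration only, not the original definition.

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import numpy as np


@dataclass
class EvaluationResults:  # hypothetical name; the original class is not shown
    labels: List[str]
    X: List[str]
    y_true: Union[List[str], List[List[str]]]
    y_pred_proba: np.ndarray
    metric_funcs: Optional[Dict[str, Callable]] = None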
Example #3
def get_label_indices(y: Union[List[str], List[List[str]]]) -> Dict[str, List[int]]:
    label_indices = defaultdict(list)
    if is_multilabel(y):
        for i, labels in enumerate(y):
            for label in labels:
                label_indices[label].append(i)
    else:
        for i, cls in enumerate(y):
            label_indices[cls].append(i)
    return label_indices
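
To illustrate, here is what `get_label_indices` returns for small single-label and multilabel inputs (hypothetical data; expected output shown as comments, assuming `is_multilabel` distinguishes the two cases as sketched above):

single = ["pos", "neg", "pos"]
multi = [["pos"], ["pos", "neg"], ["neg"]]

get_label_indices(single)
# defaultdict(<class 'list'>, {'pos': [0, 2], 'neg': [1]})

get_label_indices(multi)
# defaultdict(<class 'list'>, {'pos': [0, 1], 'neg': [1, 2]})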
Example #4
    def __post_init__(self):
        self.multilabel = is_multilabel(self.y_train)

        for X in (self.X_train, self.X_valid):
            validate_X(X)

        for y in self.y_train, self.y_valid:
            validate_multilabel_y(y, self.multilabel)

        for X, y in ((self.X_train, self.y_train),
                     (self.X_valid, self.y_valid)):
            validate_X_y(X, y)
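
The `validate_*` helpers are not shown in these examples. Minimal sketches consistent with how they are called might look like this; the checks below are assumptions, not the original implementations.

def validate_X(X):
    # Assumed check: every document must be a non-empty string
    if not all(isinstance(text, str) and len(text) > 0 for text in X):
        raise ValueError("X must be a list of non-empty strings")


def validate_multilabel_y(y, multilabel):
    # Assumed check: y must match the expected single-label/multilabel setting
    if is_multilabel(y) != multilabel:
        raise ValueError("y is inconsistent with the expected multilabel setting")


def validate_X_y(X, y):
    # Assumed check: X and y must have the same number of observations
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of observations")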
Example #5
def _collect_label_counts(
        labels: Union[List[str], List[List[str]]]) -> pd.DataFrame:
    if is_multilabel(labels):
        counts = defaultdict(int)
        for row_labels in labels:
            for label in row_labels:
                counts[label] += 1
        label_counts = pd.Series(counts) / len(labels)
    else:
        label_counts = pd.Series(labels).value_counts(normalize=True)
    label_counts = label_counts.to_frame(name="Proportion")
    label_counts.index.name = "Label"
    return label_counts.reset_index()
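
For instance, with hypothetical labels the result looks roughly as shown in the comments below; in the multilabel case each value is the fraction of documents containing that label, so the proportions can sum to more than 1.

_collect_label_counts(["pos", "pos", "neg", "pos"])
#   Label  Proportion
# 0   pos        0.75
# 1   neg        0.25

_collect_label_counts([["pos"], ["pos", "neg"], ["pos"]])
#   Label  Proportion
# 0   pos    1.000000
# 1   neg    0.333333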
Example #6
def _show_example_documents(
    texts: List[str],
    labels: Optional[Union[List[str], List[List[str]]]],
    truncate_len: int,
):
    df = pd.DataFrame(
        {"Document": [truncate_text(t, truncate_len) for t in texts]})
    if labels is not None:
        if is_multilabel(labels):
            label_col = "Labels"
        else:
            label_col = "Label"
        df[label_col] = labels
    st.table(df)
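
A hypothetical call from a Streamlit app might look like this; `truncate_text` is assumed to shorten each document to at most `truncate_len` characters.

_show_example_documents(
    texts=["This movie was great!", "Terrible acting and a weak plot."],
    labels=["pos", "neg"],
    truncate_len=100,
)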
Example #7
def show_embeddings(
    model_cls: Any,
    model_kwargs: Dict[str, Any],
    texts: List[str],
    labels: Optional[Union[List[str], List[List[str]]]],
    checkpoint_meta: Optional[Dict[str, Any]] = None,
    batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
    umap_seed: int = 1,
    umap_n_neighbors: int = 15,
    umap_metric: str = "euclidean",
    umap_min_dist: float = 0.1,
    cluster_when: str = "before",
    clusterer: Optional[Any] = None,
    show_vocab_overlap: bool = False,
):
    if cluster_when not in ("before", "after"):
        raise ValueError(f"Unexpected cluster_when value: '{cluster_when}'")

    embeddings, _ = get_embeddings(
        model_cls,
        model_kwargs,
        texts,
        checkpoint_meta=checkpoint_meta,
        batch_size=batch_size,
    )
    X_embedded = pd.DataFrame(embeddings)

    umap = UMAP(
        n_neighbors=umap_n_neighbors,
        metric=umap_metric,
        min_dist=umap_min_dist,
        random_state=umap_seed,
    )

    clusters = None
    if clusterer is not None and cluster_when == "before":
        clusterer.fit(X_embedded)
        clusters = clusterer.labels_

    umap_data = umap.fit_transform(X_embedded)

    if clusterer is not None and cluster_when == "after":
        clusterer.fit(umap_data)
        clusters = clusterer.labels_

    umap_df = pd.DataFrame(umap_data,
                           columns=["UMAP Component 1", "UMAP Component 2"])
    tooltip_attrs = ["Text"]

    dataset_is_multilabel = labels is not None and is_multilabel(labels)
    label_col_name = "Labels" if dataset_is_multilabel else "Label"

    if labels is not None:
        tooltip_attrs.append(label_col_name)
        umap_df[label_col_name] = labels

    if clusters is not None:
        color_attr = "Cluster"
        tooltip_attrs.append("Cluster")
        umap_df["Cluster"] = clusters
        umap_df["Cluster"] = umap_df["Cluster"].astype(str)
    elif labels is not None and not dataset_is_multilabel:
        # Color by label, but only for single-label datasets, since coloring
        # by label doesn't make sense for a multilabel dataset
        color_attr = label_col_name
    else:
        color_attr = None

    # NOTE: Altair (or the underlying charting library, vega-lite) will
    # truncate these texts before they are displayed
    umap_df["Text"] = texts

    umap_chart = (
        alt.Chart(umap_df, height=700, width=700)
        .mark_circle(size=60)
        .encode(
            alt.X("UMAP Component 1", scale=alt.Scale(zero=False), axis=None),
            alt.Y("UMAP Component 2", scale=alt.Scale(zero=False), axis=None),
            tooltip=alt.Tooltip(tooltip_attrs),
        )
    )
    if color_attr is not None:
        umap_chart = umap_chart.encode(alt.Color(color_attr))

    st.altair_chart(umap_chart)

    if show_vocab_overlap:
        missing_token_counts = defaultdict(int)
        oov_count = 0
        total_count = 0
        try:
            embeddings, embed_tokens = get_embeddings(
                model_cls,
                model_kwargs,
                texts,
                checkpoint_meta=checkpoint_meta,
                batch_size=batch_size,
                pooling=EmbedPooling.NONE,
            )
        except ValueError:
            st.error(
                "This model doesn't support generating token-level embeddings, "
                "so the vocabulary overlap report can't be calculated.")
            return

        for doc_embedding, doc_tokens in zip(embeddings, embed_tokens):
            # A token whose embedding is all zeros is counted as out-of-vocabulary
            oov_embeddings = np.abs(doc_embedding).sum(axis=1) == 0
            oov_count += oov_embeddings.sum()
            total_count += len(doc_tokens)
            for oov_token_ndx in np.argwhere(oov_embeddings).ravel():
                missing_token_counts[doc_tokens[oov_token_ndx]] += 1

        st.subheader("Vocabulary Overlap")
        num_oov_tokens_show = 20
        st.markdown(f"""
            - Total number of tokens: {total_count:,}
            - Number of out-of-vocabulary tokens: {oov_count:,} ({(oov_count / total_count) * 100:.2f}%)
            """)
        if oov_count > 0:
            st.markdown(f"""
                - Top {num_oov_tokens_show} most frequent out-of-vocabulary tokens:
                """)
            # missing_token_counts is sorted by descending count, so take the
            # first num_oov_tokens_show rows to get the most frequent tokens
            oov_df = pd.DataFrame.from_dict(
                {
                    tok: count
                    for tok, count in sorted(missing_token_counts.items(),
                                             key=lambda kv: -kv[1])
                },
                orient="index",
            ).head(num_oov_tokens_show)
            oov_df.columns = ["Out-of-Vocabulary Token Count"]
            st.dataframe(oov_df)
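
Finally, a hypothetical invocation of `show_embeddings`, assuming the surrounding app provides a model class compatible with `get_embeddings` and using an HDBSCAN clusterer; the names marked as placeholders below are assumptions, not part of the original code.

import hdbscan

show_embeddings(
    model_cls=MyEmbeddingModel,   # placeholder for a model class usable by get_embeddings
    model_kwargs={},
    texts=texts,                  # placeholder: list of document strings
    labels=labels,                # placeholder: single-label or multilabel labels
    clusterer=hdbscan.HDBSCAN(min_cluster_size=5),
    cluster_when="after",
    show_vocab_overlap=True,
)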