Example #1
def _load_amazon_products_dataset(datadir: str,
                                  num_categories: int = 6) -> pd.DataFrame:
    """Load Amazon products dataset from AMAZON_SAMPLE_PATH."""
    df = pd.read_csv(fewshot_filename(datadir, AMAZON_SAMPLE_PATH))
    keepers = df["category"].value_counts()[:num_categories]
    df = df[df["category"].isin(keepers.index.tolist())]
    df["category"] = pd.Categorical(df.category)
    df["label"] = df.category.cat.codes
    return df
Example #2
def _load_small_word_vector_model(cache_dir, num_most_common_words=500000):
    filename = fewshot_filename(cache_dir, W2V_SMALL)
    if not os.path.exists(filename):
        orig_model = _load_large_word_vector_model(cache_dir)
        words = orig_model.index2entity[:num_most_common_words]

        kv = KeyedVectors(vector_size=orig_model.wv.vector_size)

        vectors = []
        for word in words:
            vectors.append(orig_model.get_vector(word))

        # adds keys (words) & vectors as batch
        kv.add(words, vectors)

        # `filename` already points at the small-model path computed above
        kv.save_word2vec_format(filename, binary=True)

    return KeyedVectors.load_word2vec_format(filename, binary=True)
Example #3
def _load_large_word_vector_model(cache_dir):
    filename = fewshot_filename(cache_dir, ORIGINAL_W2V)
    if not os.path.exists(filename):
        print("No Word2Vec vectors not found. Downloading...")
        url = "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz"
        r = requests.get(url, allow_redirects=True)
        create_path(filename)
        with open(filename, "wb") as f:
            f.write(r.content)

    return KeyedVectors.load_word2vec_format(filename, binary=True)
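# Note: requests.get(url) above buffers the entire GoogleNews vectors file in memory
# via r.content before writing it out. A streamed variant could look like the sketch
# below (illustrative only; the chunk size is an arbitrary choice):
#
#     with requests.get(url, stream=True) as r, open(filename, "wb") as f:
#         for chunk in r.iter_content(chunk_size=1 << 20):
#             f.write(chunk)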
Example #4
def load_word_vector_model(small=True, cache_dir=W2VDIR):
    # TODO: be able to load GloVe or Word2Vec embedding model
    # TODO: make a smaller version that only has, say, top 100k words
    if small:
        filename = fewshot_filename(cache_dir, W2V_SMALL)
    else:
        filename = fewshot_filename(cache_dir, ORIGINAL_W2V)

    if not os.path.exists(filename):
        original_filename = fewshot_filename(cache_dir, ORIGINAL_W2V)
        if not os.path.exists(original_filename):
            print("No Word2Vec vectors not found. Downloading...")
            url = "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz"
            r = requests.get(url, allow_redirects=True)
            create_path(original_filename)
            with open(original_filename, "wb") as f:
                f.write(r.content)

        if small:
            create_small_w2v_model(cache_dir=cache_dir)

    model = KeyedVectors.load_word2vec_format(filename, binary=True)
    return model
Example #5
    def test_load_or_cache_sbert_embeddings_picks_right_dataset(
        self,
        test_name,
        input_data_name,
        target_data_name,
        mock_exists,
        mock_load_agnews,
        mock_load_amazon,
        mock_model_tokenizer,
        mock_get_embeddings,
    ):
        # Test-level constants
        FAKE_DIR = "FAKE_DIR"
        AMAZON_WORDS = ["amazon", "words"]
        AGNEWS_WORDS = ["agnews", "words"]
        OUTPUT = 123  # Doesn't resemble actual output.

        # Mock values
        mock_exists.return_value = False

        mock_load_amazon.return_value = AMAZON_WORDS
        mock_load_agnews.return_value = AGNEWS_WORDS

        # These return values are never used because the embedding call is mocked.
        mock_model_tokenizer.return_value = (None, None)

        mock_get_embeddings.return_value = OUTPUT

        # Call load_or_cache_sbert_embeddings
        self.assertEqual(
            load_or_cache_sbert_embeddings(FAKE_DIR, input_data_name), OUTPUT)

        # Expect functions are called with expected values.
        expected_filename = fewshot_filename(
            FAKE_DIR, f"{target_data_name}_embeddings.pt")
        mock_exists.assert_called_once_with(expected_filename)

        if target_data_name == "amazon":
            mock_get_embeddings.assert_called_once_with(
                AMAZON_WORDS,
                AnyObj(),
                AnyObj(),
                output_filename=expected_filename,
            )
        if target_data_name == "agnews":
            mock_get_embeddings.assert_called_once_with(
                AGNEWS_WORDS,
                AnyObj(),
                AnyObj(),
                output_filename=expected_filename,
            )
Example #6
def create_small_w2v_model(num_most_common_words=500000, cache_dir=W2VDIR):
    orig_model = load_word_vector_model(small=False, cache_dir=cache_dir)
    words = orig_model.index2entity[:num_most_common_words]

    kv = KeyedVectors(vector_size=orig_model.wv.vector_size)

    vectors = []
    for word in words:
        vectors.append(orig_model.get_vector(word))

    # adds keys (words) & vectors as batch
    kv.add(words, vectors)

    w2v_small_filename = fewshot_filename(cache_dir, W2V_SMALL)
    kv.save_word2vec_format(w2v_small_filename, binary=True)
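# Hypothetical usage sketch (argument values are the defaults shown above; the calls
# themselves are illustrative, not taken verbatim from the repo):
#
#     create_small_w2v_model(num_most_common_words=500000, cache_dir=W2VDIR)
#     small_model = load_word_vector_model(small=True, cache_dir=W2VDIR)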
Example #7
def load_or_cache_data(datadir: str,
                       dataset_name: str,
                       with_cache: bool = True) -> Dataset:
    """Loads sbert embeddings.

    First checks for a cached computation, otherwise builds the embedding with a
    call to get_transformer_embeddings using the specified dataset and standard
    model and tokenizer.

    Args:
        datadir: Where to save/load cached files.
        dataset_name: "amazon", "agnews", or "reddit".
        with_cache: If set, use cache files.  Settable for testing.

    Raises:
        ValueError: If an unexpected dataset_name is passed.

    Returns:
        The Dataset, including its embeddings.
    """
    # Check for cached data.
    print("Checking for cached data...")
    dataset_name = dataset_name.lower()
    filename = None
    if with_cache:
        filename = fewshot_filename(datadir, f"{dataset_name}_dataset.pt")
        if os.path.exists(filename):
            return pickle_load(filename)

    print(f"{dataset_name} dataset not found. Computing...")
    # Load appropriate data
    if dataset_name == "amazon":
        df = _load_amazon_products_dataset(datadir)
        text_column, category_column = "description", "category"
    elif dataset_name == "agnews":
        df = _load_agnews_dataset()
        text_column, category_column = "text", "category"
    elif dataset_name == "reddit":
        df = _load_reddit_dataset(datadir)
        text_column, category_column = "summary", "category"
    else:
        raise ValueError(f"Unexpected dataset name: {dataset_name}.\n \
                          Please choose from: agnews, amazon, or reddit")

    dataset = _create_dataset_from_df(df, text_column, filename=filename)
    return dataset
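# A minimal usage sketch (the data directory mirrors the walkthrough scripts; the
# call itself is illustrative, not taken from the repo):
#
#     dataset = load_or_cache_data("data/agnews", "agnews")
#     # The first call computes and pickles the Dataset; later calls load the cache.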
Example #8
    def test_load_or_cache_sbert_embeddings_picks_right_dataset(
            self, mock_exists):
        # Test-level constants
        FAKE_DIR = "FAKE_DIR"
        bad_name = "bad_name"

        # Mock value
        mock_exists.return_value = False

        # Call load_or_cache_sbert_embeddings
        with self.assertRaisesRegex(ValueError,
                                    f"Unexpected dataset name: {bad_name}"):
            load_or_cache_sbert_embeddings(FAKE_DIR, bad_name)

        # Expect functions are called with expected values.
        mock_exists.assert_called_once_with(
            fewshot_filename(FAKE_DIR, f"{bad_name}_embeddings.pt"))
Example #9
def _load_reddit_dataset(datadir: str,
                         categories: str = "curated") -> pd.DataFrame:
    """
    Load a curated, smaller version of the Reddit dataset included with the repo.

    There are three options for `categories`:
        1. (default) "curated" returns reddit examples from popular subreddits
            that have more meaningful subreddit names
        2. "top10" returns reddit examples from the most popular subreddits,
            regardless of how meaningful the subreddit name is
        3. Anything else returns all the possible categories (16 in total)
    """
    # TODO: the dataset included with the repo no longer allows a choice between
    #       "curated" and "top10" -- curated subreddits only; should update this
    #       and remove that functionality
    df = pd.read_csv(fewshot_filename(datadir, REDDIT_SAMPLE_PATH))
    curated_subreddits = [
        "relationships",
        "trees",
        "gaming",
        "funny",
        "politics",
        "sex",
        "Fitness",
        "worldnews",
        "personalfinance",
        "technology",
    ]
    top10_subreddits = df["category"].value_counts()[:10]

    if categories == "curated":
        df = df[df["subreddit"].isin(curated_subreddits)]
    elif categories == "top10":
        df = df[df["subreddit"].isin(top10_subreddits.index.tolist())]
    df["category"] = pd.Categorical(df.category)
    df["label"] = df.category.cat.codes
    return df
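# Hypothetical usage sketch (the data directory below is an assumption):
#
#     df_curated = _load_reddit_dataset("data/reddit")                    # curated subreddits
#     df_top10 = _load_reddit_dataset("data/reddit", categories="top10")  # most popular subreddits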
Example #10
                                     prepare_dataloader, train)

DATASET_NAME = "agnews"
DATADIR = f"data/{DATASET_NAME}"

## Load Training Data

# This loads all 120k examples in the AG News training set
df_news_train = _load_agnews_dataset(split="train")

# We want to explore learning from a limited number of training samples so we select
# a subsample containing just 400 examples (100 from each of the 4 categories).
df_news_train_subset = select_subsample(df_news_train, sample_size=100)
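# Roughly, this stratified subsampling amounts to a pandas groupby sample, e.g.
# (a sketch of the idea only; select_subsample's actual implementation may differ):
#
#     df_news_train.groupby("category", group_keys=False).sample(n=100, random_state=42)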

# convert that DataFrame to a Dataset
ds_filename = fewshot_filename(f"{DATADIR}/{DATASET_NAME}_train_dataset.pkl")
if os.path.exists(ds_filename):
    news_train_subset = pickle_load(ds_filename)
else:
    news_train_subset = _create_dataset_from_df(df_news_train_subset,
                                                text_column='text',
                                                filename=ds_filename)
# this is required due to the particular implementation details of our Dataset class
news_train_subset = expand_labels(news_train_subset)

## Load Zmap
# We'll proceed under the assumption that the Zmap we learned during on-the-fly
# classification provides the best representations for our text and labels.
Zmap = torch.load(fewshot_filename("data/maps/Zmap_20000_words.pt"))

## Prepare a Torch DataLoader for training
Example #11
# and then perform classification with cosine similarity as before

# Load the w2v embedding model
w2v_model = load_word_vector_model(small=True, cache_dir=W2VDIR)
VOCAB_SIZE = 20000

# We found that using a vocabulary size of 20,000 words is good for most applications
vocab_w2v_embeddings, vocab = get_topk_w2v_vectors(w2v_model, k=VOCAB_SIZE)
vocab_w2v_embeddings = to_tensor(vocab_w2v_embeddings)

# Passing 20k words through SBERT can be time-consuming, even with a GPU.
# Fortunately, we've already performed this step and include precomputed embeddings.
vocab_sbert_filename = fewshot_filename(
    W2VDIR, f"sbert_embeddings_for_{VOCAB_SIZE}_words.pt")

if os.path.exists(vocab_sbert_filename):
    cached_data = torch_load(vocab_sbert_filename)
    vocab_sbert_embeddings = cached_data["embeddings"]
else:
    model, tokenizer = load_transformer_model_and_tokenizer()
    vocab_sbert_embeddings = get_transformer_embeddings(
        vocab, model, tokenizer, output_filename=vocab_sbert_filename)

# Perform ordinary least-squares linear regression to learn Zmap
Zmap = OLS_with_l2_regularization(vocab_sbert_embeddings, vocab_w2v_embeddings)
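
# For intuition: the Zmap learned above is essentially a ridge-regression map from
# SBERT space into w2v space. A closed-form sketch of that idea (not necessarily how
# OLS_with_l2_regularization is implemented; the penalty value is an assumption):
import torch

_X = vocab_sbert_embeddings  # shape: (vocab size, SBERT dim)
_Y = vocab_w2v_embeddings    # shape: (vocab size, w2v dim)
_lam = 1.0                   # assumed L2 regularization strength
Zmap_sketch = torch.linalg.solve(
    _X.T @ _X + _lam * torch.eye(_X.shape[1], dtype=_X.dtype, device=_X.device),
    _X.T @ _Y,
)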

score, predictions = predict_and_score(dataset,
                                       linear_maps=[Zmap],
                                       return_predictions=True)
Example #12
            color=color,
            markeredgecolor="k",
        )

    plt.axis("off")

    st.pyplot(fig=fig)


EXAMPLES, title_to_idx, dataset = load_examples("agnews")
LABELS = dataset.categories

MAPPINGS = load_linear_maps()

### ------- SIDEBAR ------- ###
image = Image.open(fewshot_filename(IMAGEDIR, "cloudera-fast-forward.png"))
st.sidebar.image(image, use_column_width=True)
st.sidebar.markdown(
    "This prototype accompanies our [Few-Shot Text Classification](LINK) report in which we\
     explore how text embeddings can be used for few- and zero-shot text classification."
)
st.sidebar.markdown(
    "In this technique, the text and each of the labels is embedded into the same embedding space.\
    The text is then assigned the label whose embedding is most similar to the text's embedding."
)
st.sidebar.markdown("")

st.sidebar.markdown("#### Model")
# TODO: Add other model options?
st.sidebar.markdown(
    "Text and label embeddings are first computed with **Sentence-BERT**. Used alone,\