コード例 #1
0
    def test_on_the_fly(self):
        """Rebuild the SBERT -> word2vec map in-process and check the scores.

        Only meaningful when the cached data/embedding files are already on
        disk (verified first via ``self._assert_files_exist()``). Fits
        ``zmap`` with ``OLS_with_l2_regularization`` from the cached SBERT
        vocabulary embeddings to the word2vec vectors returned by
        ``get_topk_w2v_vectors(..., k=VOCAB_SIZE)``. It then scores the
        dataset with that single map and compares both the plain score and
        the ``simple_topk_accuracy`` score against recorded values.
        """
        # Refuse to run unless every required file is already present.
        self._assert_files_exist()

        data = load_or_cache_data(DATADIR, DATASET_NAME)

        # word2vec side: VOCAB_SIZE vectors (and their words, unused here),
        # converted to a tensor.
        word_model = load_word_vector_model(small=True, cache_dir=W2VDIR)
        w2v_vectors, _vocab = get_topk_w2v_vectors(word_model, k=VOCAB_SIZE)
        w2v_vectors = to_tensor(w2v_vectors)

        # SBERT side: pre-computed embeddings for the same vocabulary, stored
        # under the "embeddings" key of the cached file.
        sbert_path = fewshot_filename(
            W2VDIR, f"sbert_embeddings_for_{VOCAB_SIZE}_words.pt"
        )
        sbert_vectors = torch_load(sbert_path)["embeddings"]

        # L2-regularized least-squares linear map between the two embedding
        # spaces.
        zmap = OLS_with_l2_regularization(sbert_vectors, w2v_vectors)

        accuracy, predicted = predict_and_score(
            data, linear_maps=[zmap], return_predictions=True
        )
        topk_accuracy = simple_topk_accuracy(data.labels, predicted)

        self.assertAlmostEqual(accuracy, 65.5657894736842)
        self.assertAlmostEqual(topk_accuracy, 96.01315789473685)
コード例 #2
0
# Few-shot stage: a linear model whose input and output widths both equal
# Zmap's second dimension. It is trained below and its weights become Wmap.
fewshot_model = FewShotLinearRegression(
    Zmap.size(1),
    Zmap.size(1),
    loss_fcn=BayesianMSELoss(device=device),
    lr=learning_rate,
    device=device,
)

# Run training; whatever `train` returns is kept as the loss history.
loss_history = train(
    fewshot_model,
    data_loader,
    num_epochs=num_epochs,
    lam=lambda_regularization,
)

# Wmap is the trained linear layer's weight, detached from autograd and
# moved to the CPU.
Wmap = fewshot_model.linear.weight.detach().cpu()

## Test
# Wmap was fit to pull training examples toward their labels; evaluate it
# on the test set, applied after Zmap.
test_dataset = load_or_cache_data(DATADIR, DATASET_NAME)

score = predict_and_score(
    test_dataset,
    linear_maps=[Zmap, Wmap],
    return_predictions=False,
)
print(score)

## Save
# Persist Wmap so later runs can reuse it without retraining.
torch_save(Wmap, fewshot_filename(f"data/maps/Wmap_{DATASET_NAME}.pt"))
コード例 #3
0
)

# --- Experiment configuration -------------------------------------------
DATASET_NAME = "agnews"
DATADIR = f"data/{DATASET_NAME}"
W2VDIR = "data/w2v"
# Only used in the printed report below.
# NOTE(review): TOPK is not passed to simple_topk_accuracy; confirm that
# function's own top-k matches this value.
TOPK = 3

## Load data
# Per the original notes: the first call to `load_or_cache_data` downloads
# the agnews dataset from the HuggingFace Datasets repository, caches it,
# and processes it for this analysis. The returned object holds, for every
# test-set example, the original text, its SentenceBERT embedding, and its
# label.
dataset = load_or_cache_data(DATADIR, DATASET_NAME)

## Baseline: score without passing any `linear_maps`
score, predictions = predict_and_score(
    dataset,
    return_predictions=True,
)
score_intop3 = simple_topk_accuracy(dataset.labels, predictions)

# Report both scores, followed by a blank line.
print(f"Score: {score}")
print(f"Score considering the top {TOPK} best labels: {score_intop3}")
print()
## Let's make this model a bit better!

### Learn a mapping between SBERT sentence embeddings and Word2Vec word embeddings
# SBERT is optimized for sentences, and word2vec is optimized for words. To get
# the best of both worlds, we'll learn a mapping between these two embedding
# spaces that we can use during classification.

# To learn a mapping, we'll need to
# -- identify a large vocabulary, V, of popular words,