Code example #1
    def test_load_or_cache_amazon_with_alternative_capitalization(
        self,
        mock_load_amazon,
    ):
        mock_load_amazon.return_value = pd.DataFrame({
            "description": ["X", "Y"],  # Must be named description for Amazon
            "label": [1, 2],
            "category": ["cat1", "cat2"],
        })

        # Call load_or_cache_data; capitalization of the dataset name is ignored.
        load_or_cache_data(FAKE_DIR, "amAzOn", with_cache=False)
        mock_load_amazon.assert_called_once()
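These test methods take a mock_load_amazon argument, which implies they sit inside a unittest.TestCase with a patched data loader. A minimal sketch of the assumed scaffolding; the patch target and the FAKE_DIR value are hypothetical:

import unittest
from unittest.mock import patch

FAKE_DIR = "/tmp/fake_data"  # hypothetical cache directory for these tests

class TestLoadOrCacheData(unittest.TestCase):
    # The patch target is an assumption; the real module path may differ.
    @patch("fewshot.data.loaders._load_amazon_products_dataset")
    def test_load_or_cache_amazon_with_alternative_capitalization(
        self, mock_load_amazon,
    ):
        ...  # body as shown above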
Code example #2
    def test_on_the_fly(self):
        # Test should only be run if the necessary files already exist.
        self._assert_files_exist()

        # Load dataset
        dataset = load_or_cache_data(DATADIR, DATASET_NAME)

        # Load w2v embeddings
        w2v_model = load_word_vector_model(small=True, cache_dir=W2VDIR)
        vocab_w2v_embeddings, vocab = get_topk_w2v_vectors(w2v_model, k=VOCAB_SIZE)
        vocab_w2v_embeddings = to_tensor(vocab_w2v_embeddings)

        # Load SBERT embeddings
        vocab_sbert_filename = fewshot_filename(
            W2VDIR, f"sbert_embeddings_for_{VOCAB_SIZE}_words.pt"
        )
        cached_data = torch_load(vocab_sbert_filename)
        vocab_sbert_embeddings = cached_data["embeddings"]

        # Calculate the linear map of best fit between the two embedding spaces.
        Zmap = OLS_with_l2_regularization(
            vocab_sbert_embeddings, vocab_w2v_embeddings
        )

        # Predict and score
        score, predictions = predict_and_score(
            dataset, linear_maps=[Zmap], return_predictions=True
        )
        score3 = simple_topk_accuracy(dataset.labels, predictions)

        self.assertAlmostEqual(score, 65.5657894736842)
        self.assertAlmostEqual(score3, 96.01315789473685)
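Zmap above is the least-squares linear map from the SBERT vocabulary embeddings onto the w2v embeddings. A minimal sketch of what OLS_with_l2_regularization presumably computes (the function name below is my own; the library's actual signature and default regularization may differ):

import torch

def ols_with_l2_sketch(X, Y, alpha=0.0):
    # Solve min_Z ||X @ Z - Y||^2 + alpha * ||Z||^2 via the normal equations
    # (ridge regression). X: (n, d_in), Y: (n, d_out) -> Z: (d_in, d_out).
    # With alpha=0.0 this reduces to plain OLS.
    d = X.shape[1]
    A = X.T @ X + alpha * torch.eye(d, dtype=X.dtype)
    return torch.linalg.solve(A, X.T @ Y)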
Code example #3
def load_examples(data_name="agnews"):
    if data_name not in ["agnews", "reddit"]:
        raise ValueError(f"Dataset name not found: {data_name!r}")

    dataset = load_or_cache_data(DATADIR + "/" + data_name, data_name)

    if data_name == "agnews":
        # cherry-picked example indexes
        #example_idx = [142, 811, 1201, 1440, 1767, 1788]
        example_idx = [200, 1582, 2754, 3546, 3825, 5129, 6574]
        titles = [
            "Strong Family Equals Strong Education",
            "Hurricane Ivan Batters Grand Cayman",
            "Supernova Warning System Will Give Astronomers Earlier Notice",
            "Study: Few Americans Buy Drugs Online",
            "Red Sox Feeling Heat of 0-2 Start in ALCS",
            "Product Previews palmOneUpgrades Treo With Faster Chip, Better Display",
            "Is this the end of IT as we know it? "
        ]
    else:
        # No cherry-picked examples for "reddit" are included in this excerpt;
        # guard against example_idx being undefined in the loop below.
        raise NotImplementedError(f"No example indexes defined for {data_name!r}")

    examples = {}
    title_to_idx = {}
    for i, idx in enumerate(example_idx):
        text = dataset.examples[idx]
        title = titles[i]
        examples[title] = text
        title_to_idx[title] = idx

    return examples, title_to_idx, dataset
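Typical usage of the helper (the printed output depends on the cached AG News data):

examples, title_to_idx, dataset = load_examples("agnews")
for title, idx in title_to_idx.items():
    print(f"[{idx}] {title}")
    print(examples[title][:100])  # first 100 characters of the article text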
Code example #4
    def test_category_sorting(
        self,
        mock_load_amazon,
    ):
        mock_load_amazon.return_value = pd.DataFrame({
            "description": ["A", "B", "C", "D", "E"],
            "label": [3, 1, 2, 1, 3],
            "category": ["cat3", "cat1", "cat2", "cat1", "cat3"],
        })

        expected_dataset = Dataset(
            examples=["A", "B", "C", "D", "E"],
            labels=[3, 1, 2, 1, 3],
            # Must go in order of label.
            categories=["cat1", "cat2", "cat3"],
        )

        # Call load_or_cache_data and check that categories come back sorted by label.
        self.assertEqual(
            load_or_cache_data(FAKE_DIR, "amazon", with_cache=False),
            expected_dataset)
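The expectation encoded in this test is that the loader deduplicates (label, category) pairs and orders the categories by label. A standalone pandas illustration of that behavior (not the library's actual code):

import pandas as pd

df = pd.DataFrame({
    "label": [3, 1, 2, 1, 3],
    "category": ["cat3", "cat1", "cat2", "cat1", "cat3"],
})
categories = (
    df[["label", "category"]]
    .drop_duplicates()
    .sort_values("label")["category"]
    .tolist()
)
print(categories)  # ['cat1', 'cat2', 'cat3']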
Code example #5
    def test_load_or_cache_amazon(
        self,
        mock_load_amazon,
    ):
        mock_load_amazon.return_value = pd.DataFrame({
            "description": ["X", "Y"],  # Must be named description for Amazon
            "label": [1, 2],
            "category": ["cat1", "cat2"],
        })

        expected_dataset = Dataset(examples=["X", "Y"],
                                   labels=[1, 2],
                                   categories=["cat1", "cat2"])

        # Call load_or_cache_data.
        self.assertEqual(
            load_or_cache_data(FAKE_DIR, "amazon", with_cache=False),
            expected_dataset)
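The assertEqual checks in these tests compare Dataset objects by value, so Dataset must implement field-wise equality. If the real class does not already provide it, a dataclass-style definition would (a sketch, not the library's actual implementation):

from dataclasses import dataclass
from typing import List

@dataclass
class Dataset:
    examples: List[str]
    labels: List[int]
    categories: List[str]
    # @dataclass generates __eq__, so unittest's assertEqual compares fields.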
Code example #6
    def test_load_or_cache_reddit(
        self,
        mock_load_reddit,
    ):
        mock_load_reddit.return_value = pd.DataFrame({
            "summary": ["X", "Y"],  # Must be named summary for reddit
            "label": [1, 2],
            "category": ["cat1", "cat2"],
        })

        expected_dataset = Dataset(
            examples=["X", "Y"],
            labels=[1, 2],
            categories=["cat1", "cat2"],
        )

        # Call load_or_cache_data.
        self.assertEqual(
            load_or_cache_data(FAKE_DIR, "reddit", with_cache=False),
            expected_dataset)
Code example #7
    def test_load_or_cache_agnews(
        self,
        mock_load_agnews,
    ):
        mock_load_agnews.return_value = pd.DataFrame({
            "text": ["X", "Y"],  # Must be named text for AGNews
            "label": [1, 2],
            "category": ["cat1", "cat2"],
        })

        expected_dataset = Dataset(
            examples=["X", "Y"],
            labels=[1, 2],
            categories=["cat1", "cat2"],
        )

        # Call load_or_cache_data.
        self.assertEqual(
            load_or_cache_data(FAKE_DIR, "agnews", with_cache=False),
            expected_dataset)
Code example #8
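# Note: both model dimensions below equal Zmap.size()[1] (the word-vector
# embedding size), so Wmap is learned as a square map within that space.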
fewshot_model = FewShotLinearRegression(
    Zmap.size()[1],
    Zmap.size()[1],
    loss_fcn=BayesianMSELoss(device=device),
    lr=learning_rate,
    device=device)
# train!
loss_history = train(fewshot_model,
                     data_loader,
                     num_epochs=num_epochs,
                     lam=lambda_regularization)

# after training we can extract Wmap (the weights of the linear model)
Wmap = fewshot_model.linear.weight.detach().cpu()

## Test
# Wmap has learned to map training examples to their labels.
# We can now apply it, together with Zmap, to the test set.

# load the test set
test_dataset = load_or_cache_data(DATADIR, DATASET_NAME)

score = predict_and_score(test_dataset,
                          linear_maps=[Zmap, Wmap],
                          return_predictions=False)
print(score)

## Success!
# Let's save this Wmap
torch_save(Wmap, fewshot_filename(f"data/maps/Wmap_{DATASET_NAME}.pt"))
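Assuming torch_load is the counterpart of the torch_save wrapper used above, the saved map can later be reloaded and composed with Zmap for prediction (a usage sketch):

# Reload the trained map and score the test set with both linear maps.
Wmap = torch_load(fewshot_filename(f"data/maps/Wmap_{DATASET_NAME}.pt"))
score = predict_and_score(test_dataset,
                          linear_maps=[Zmap, Wmap],
                          return_predictions=False)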