Пример #1
0
def _create_dataset_from_df(df, text_column: str, filename: "str | None" = None):
    """Build a ``Dataset`` from a DataFrame and optionally cache it to disk.

    Args:
        df: DataFrame holding the column named by ``text_column`` and a
            ``label`` column; category names are derived from it via
            ``_prepare_category_names``.
        text_column: Name of the column whose values become the examples.
        filename: If given, SBERT embeddings are computed on the dataset and
            the dataset is then pickled to this path. If ``None`` (the
            default), no embeddings are computed and nothing is written.

    Returns:
        The constructed ``Dataset``.
    """
    dataset = Dataset(
        examples=df[text_column].tolist(),
        # Subscript access, consistent with ``df[text_column]`` above; attribute
        # access can be shadowed by DataFrame members of the same name.
        labels=df["label"].tolist(),
        categories=_prepare_category_names(df),
    )

    if filename is not None:
        # Embeddings are computed only on the caching path, right before the
        # dataset is pickled.
        dataset.calc_sbert_embeddings()
        pickle_save(dataset, filename)
    return dataset
Пример #2
0
    def test_load_or_cache_amazon(
        self,
        mock_load_amazon,
    ):
        """Amazon rows are loaded from the ``description`` column into a Dataset."""
        # The Amazon loader must expose its text under a column named "description".
        frame = pd.DataFrame(
            {
                "description": ["X", "Y"],
                "label": [1, 2],
                "category": ["cat1", "cat2"],
            }
        )
        mock_load_amazon.return_value = frame

        expected_dataset = Dataset(
            examples=["X", "Y"],
            labels=[1, 2],
            categories=["cat1", "cat2"],
        )

        actual_dataset = load_or_cache_data(FAKE_DIR, "amazon", with_cache=False)
        self.assertEqual(actual_dataset, expected_dataset)
Пример #3
0
    def test_category_sorting(
        self,
        mock_load_amazon,
    ):
        """Category names are ordered by label, not by first appearance."""
        mock_load_amazon.return_value = pd.DataFrame({
            "description": ["A", "B", "C", "D", "E"],
            "label": [3, 1, 2, 1, 3],
            "category": ["cat3", "cat1", "cat2", "cat1", "cat3"],
        })

        expected_dataset = Dataset(
            examples=["A", "B", "C", "D", "E"],
            labels=[3, 1, 2, 1, 3],
            # Must go in order of label.
            categories=["cat1", "cat2", "cat3"],
        )

        # Compare against the expected dataset so that wrong category ordering
        # actually fails this test.
        # NOTE(review): a previous comment claimed capitalization ("AmaZon") is
        # ignored, but the call uses lowercase "amazon"; add a dedicated test if
        # case-insensitive dataset names are a supported feature.
        self.assertEqual(
            load_or_cache_data(FAKE_DIR, "amazon", with_cache=False),
            expected_dataset)
Пример #4
0
    def test_load_or_cache_reddit(
        self,
        mock_load_reddit,
    ):
        """Reddit rows are loaded from the ``summary`` column into a Dataset."""
        texts = ["X", "Y"]
        labels = [1, 2]
        categories = ["cat1", "cat2"]

        # The Reddit loader must expose its text under a column named "summary".
        mock_load_reddit.return_value = pd.DataFrame(
            {"summary": texts, "label": labels, "category": categories}
        )

        expected_dataset = Dataset(
            examples=["X", "Y"],
            labels=[1, 2],
            categories=["cat1", "cat2"],
        )

        loaded = load_or_cache_data(FAKE_DIR, "reddit", with_cache=False)
        self.assertEqual(loaded, expected_dataset)
Пример #5
0
    def test_load_or_cache_agnews(
        self,
        mock_load_agnews,
    ):
        """AG News rows are loaded from the ``text`` column into a Dataset."""
        # The AG News loader must expose its text under a column named "text".
        columns = {
            "text": ["X", "Y"],
            "label": [1, 2],
            "category": ["cat1", "cat2"],
        }
        mock_load_agnews.return_value = pd.DataFrame(columns)

        expected_dataset = Dataset(
            examples=["X", "Y"],
            labels=[1, 2],
            categories=["cat1", "cat2"],
        )

        result = load_or_cache_data(FAKE_DIR, "agnews", with_cache=False)
        self.assertEqual(result, expected_dataset)