Example #1
0
 def test_get_datasets(self):
     """Check that `get_datasets` produces train/test frames with the expected shape.

     Loads the bundled prepared data and requests 1000 training rows and 100
     test rows with ``typo_probability`` 0.5. Then asserts that:
     the token/split columns are present, exactly 500 training tokens differ
     from their correct form, and the frame lengths match the config.
     """
     prepared = pandas.read_csv(str(TEST_DATA_DIR / "prepared_data.csv.xz"),
                                index_col=0,
                                keep_default_na=False)
     config = {
         "train_size": 1000,
         "test_size": 100,
         "typo_probability": 0.5,
         "add_typo_probability": 0.05,
         "train_path": None,
         "test_path": None,
     }
     train, test = get_datasets(prepared, config)
     expected_columns = {
         Columns.Token, Columns.CorrectToken, Columns.Split,
         Columns.CorrectSplit,
     }
     # Report exactly which columns are absent instead of a bare True/False.
     self.assertEqual(expected_columns - set(train.columns), set())
     # A row counts as corrupted when its token differs from the correct one;
     # the vectorised Series.sum() avoids iterating the mask in Python.
     corrupted = (train[Columns.Token] != train[Columns.CorrectToken]).sum()
     self.assertEqual(corrupted, 500)
     self.assertEqual(len(train), 1000)
     self.assertEqual(len(test), 100)
Example #2
0
def cli_get_datasets(data_path: str, config: Mapping[str, Any]) -> None:
    """Command-line entry point: load prepared data and run `get_datasets`.

    :param data_path: Path to the prepared-data CSV. Its first column becomes
        the index. Default NaN parsing is disabled (``keep_default_na=False``),
        so values such as "NA" or empty fields are kept as text.
    :param config: Configuration mapping, passed through unchanged to
        `get_datasets`. The return value of `get_datasets` is discarded.
    """
    prepared_data = pandas.read_csv(
        data_path, index_col=0, keep_default_na=False)
    get_datasets(prepared_data, config)