def test_get_datasets(self):
    """Check the frames returned by `get_datasets` for a fixed config.

    Loads the prepared fixture CSV, requests 1000 train / 100 test rows and
    verifies the expected columns, the number of corrupted tokens in the
    train frame, and both frame lengths.
    """
    fixture_path = str(TEST_DATA_DIR / "prepared_data.csv.xz")
    # keep_default_na=False: do not parse the default NaN markers as missing.
    prepared = pandas.read_csv(
        fixture_path, index_col=0, keep_default_na=False)
    config = {
        "train_size": 1000,
        "test_size": 100,
        "typo_probability": 0.5,
        "add_typo_probability": 0.05,
        "train_path": None,
        "test_path": None,
    }
    train, test = get_datasets(prepared, config)

    required_columns = {
        Columns.Token,
        Columns.CorrectToken,
        Columns.Split,
        Columns.CorrectSplit,
    }
    self.assertTrue(required_columns <= set(train.columns))

    # Rows whose token differs from its correct form count as corrupted.
    # 500 is presumably train_size * typo_probability -- confirm against
    # get_datasets.
    mismatch_mask = train[Columns.Token] != train[Columns.CorrectToken]
    corrupted = sum(mismatch_mask)
    self.assertEqual(corrupted, 500)
    self.assertEqual(len(train), 1000)
    self.assertEqual(len(test), 100)

    # Informational only: number of distinct tokens present in both frames.
    train_tokens = set(train[Columns.Token])
    test_tokens = set(test[Columns.Token])
    print(len(train_tokens.intersection(test_tokens)))
def cli_get_datasets(data_path: str, config: Mapping[str, Any]) -> None:
    """Entry point for `get_datasets`.

    Reads the CSV at `data_path` (first column used as the index, default
    NaN markers not parsed as missing values) and passes the resulting frame
    together with `config` to `get_datasets`. The value returned by
    `get_datasets` is discarded.

    :param data_path: Path to the CSV file to load.
    :param config: Configuration mapping forwarded unchanged to \
                   `get_datasets`.
    """
    frame = pandas.read_csv(data_path, index_col=0, keep_default_na=False)
    get_datasets(frame, config)