def load_train_test_dfs(local_cache_path="./",
                        test_fraction=0.5,
                        random_seed=None):
    """
    Get the training and testing data frames based on test_fraction.

    The data file is parsed with ``preprocess_conll``, the resulting
    (sentence, labels) pairs are shuffled together, and the first
    ``round(count * test_fraction)`` pairs become the test set while the
    remainder becomes the training set.

    Args:
        local_cache_path (str): Path to store the data. If the data file
            doesn't exist in this path, it's downloaded.
        test_fraction (float, optional): Fraction of data to use for
            testing. Since this is a small dataset, the default testing
            fraction is set to 0.5
        random_seed (int or float, optional): Random seed used to shuffle
            the data. Any value other than None (including 0) seeds the
            global ``random`` module before shuffling. If None, the current
            state of the global generator is used.

    Returns:
        tuple: (train_pandas_df, test_pandas_df), each data frame contains
            two columns
            "sentence": sentences in strings.
            "labels": list of entity labels of the words in the sentence.

    """
    file_name = URL.split("/")[-1]
    maybe_download(URL, file_name, local_cache_path)

    data_file = os.path.join(local_cache_path, file_name)

    with open(data_file, "r", encoding="utf8") as file:
        text = file.read()

    sentence_list, labels_list = preprocess_conll(text)

    # Compare against None rather than relying on truthiness, so that a
    # seed of 0 is honoured instead of being treated as "no seed".
    if random_seed is not None:
        random.seed(random_seed)

    # Shuffle sentences and labels as pairs so they stay aligned.
    sentence_and_labels = list(zip(sentence_list, labels_list))
    random.shuffle(sentence_and_labels)

    # Split the shuffled pairs directly instead of unzipping them back into
    # the source lists: unpacking zip(*pairs) fails when there are no pairs,
    # and writing into the parsed lists would mutate them in place.
    test_sentence_count = round(len(sentence_and_labels) * test_fraction)
    test_pairs = sentence_and_labels[:test_sentence_count]
    train_pairs = sentence_and_labels[test_sentence_count:]

    train_df = pd.DataFrame({
        "sentence": [sentence for sentence, _ in train_pairs],
        "labels": [labels for _, labels in train_pairs],
    })

    test_df = pd.DataFrame({
        "sentence": [sentence for sentence, _ in test_pairs],
        "labels": [labels for _, labels in test_pairs],
    })

    return (train_df, test_df)
Ejemplo n.º 2
0
def test_ner_utils(ner_utils_test_data):
    """Check preprocess_conll against the fixture's expected output.

    Args:
        ner_utils_test_data: Mapping that provides the raw text under
            "input" and the expected parse result under "expected_output".
    """
    raw_text = ner_utils_test_data["input"]
    actual = preprocess_conll(raw_text)
    expected = ner_utils_test_data["expected_output"]
    assert actual == expected