示例#1
0
def test_not_balance_model_data(model_data: RasaModelData):
    """Balancing keyed on ``entities``/``tag_ids`` leaves those tag ids as-is.

    A fresh ``RasaModelData`` is built over the fixture's data, using
    ``entities``/``tag_ids`` as the label key/sub-key. After calling
    ``balanced_data`` on it, the test checks that the returned
    ``entities``/``tag_ids`` entry equals the one read via ``get``.
    """
    wrapped = RasaModelData(
        label_key="entities", label_sub_key="tag_ids", data=model_data.data
    )

    # NOTE(review): the positional 2 / False are presumably batch size and a
    # shuffle flag; confirm against the balanced_data signature.
    result = wrapped.balanced_data(wrapped.data, 2, False)

    returned_tag_ids = result["entities"]["tag_ids"]
    original_tag_ids = wrapped.get("entities", "tag_ids")
    assert np.all(returned_tag_ids == original_tag_ids)
示例#2
0
def test_balance_model_data(model_data: RasaModelData):
    """Balancing the fixture data yields the expected first ``label``/``ids`` entry.

    The expected sequence ``[0, 1, 1, 0, 1]`` is tied to the contents of the
    ``model_data`` fixture.
    """
    # NOTE(review): 2 / False presumably mean batch size and shuffle flag;
    # confirm against the balanced_data signature.
    balanced = model_data.balanced_data(model_data.data, 2, False)

    first_label_ids = np.array(balanced["label"]["ids"][0])
    expected_ids = np.array([0, 1, 1, 0, 1])
    assert np.all(first_label_ids == expected_ids)