def test_split():
    """Check that ``split([0.5])`` and ``split(2)`` agree with each other.

    Both calls are unpacked into a (train, test) pair; the pairs must
    compare equal, and stacking the first pair's feature arrays must
    reproduce the features of the unsplit dataset.
    """
    dataset = StructuredDataset(
        df=df,
        label_names=['label'],
        protected_attribute_names=['two'],
    )

    ratio_train, ratio_test = dataset.split([0.5])
    count_train, count_test = dataset.split(2)

    # The two ways of requesting the split must yield equal parts.
    assert ratio_train == count_train
    assert ratio_test == count_test

    # Train rows followed by test rows must match the original features.
    stacked_features = np.concatenate(
        (ratio_train.features, ratio_test.features)
    )
    assert np.all(stacked_features == dataset.features)
def test_k_folds():
    """Check the sizes of the folds returned by ``split(4)`` and ``split(3)``.

    With four folds, every fold must report exactly one entry in each of
    its per-instance fields. With three folds, the first fold must hold
    two feature rows.
    """
    dataset = StructuredDataset(
        df=df,
        label_names=['label'],
        protected_attribute_names=['two'],
    )

    four_folds = dataset.split(4)
    assert len(four_folds) == 4
    # NOTE(review): the expected sizes presumably follow from the
    # module-level ``df`` having four rows -- confirm against the fixture.
    assert all(
        fold.features.shape[0]
        == fold.labels.shape[0]
        == fold.protected_attributes.shape[0]
        == len(fold.instance_names)
        == fold.instance_weights.shape[0]
        == 1
        for fold in four_folds
    )

    three_folds = dataset.split(3)
    leading_fold = three_folds[0]
    assert leading_fold.features.shape[0] == 2